diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index c0258da704be..2bc2222dd24b 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1033,6 +1033,10 @@
title: DePlot
- local: model_doc/donut
title: Donut
+ - local: model_doc/edgetam
+ title: EdgeTAM
+ - local: model_doc/edgetam_video
+ title: EdgeTamVideo
- local: model_doc/emu3
title: Emu3
- local: model_doc/evolla
diff --git a/docs/source/en/model_doc/edgetam.md b/docs/source/en/model_doc/edgetam.md
new file mode 100644
index 000000000000..780ccb3f70b3
--- /dev/null
+++ b/docs/source/en/model_doc/edgetam.md
@@ -0,0 +1,331 @@
+
+*This model was released on 2025-01-13 and added to Hugging Face Transformers on 2025-09-29.*
+
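+A minimal image-segmentation sketch: it follows the auto mappings added in this PR (`Sam2Processor`/`Sam2ImageProcessorFast` and `EdgeTamModel` for the `edgetam` model type); the checkpoint id `facebook/EdgeTAM` and the point prompt values are illustrative assumptions taken from the conversion script.
+
+```python
+import requests
+import torch
+from PIL import Image
+
+from transformers import EdgeTamModel, Sam2Processor
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Assumed checkpoint id; replace with the converted EdgeTAM repository you want to use.
+model = EdgeTamModel.from_pretrained("facebook/EdgeTAM").to(device)
+processor = Sam2Processor.from_pretrained("facebook/EdgeTAM")
+
+img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
+raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+
+# A single (x, y) point prompt with a positive label.
+input_points = [[[[1000, 600]]]]
+input_labels = [[[1]]]
+
+inputs = processor(images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
+
+with torch.no_grad():
+    outputs = model(**inputs)
+
+# Low-resolution masks and their predicted IoU scores; the masks can be resized to the
+# original image size with the processor's post-processing utilities.
+masks = outputs.pred_masks
+scores = outputs.iou_scores
+```
+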
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 2905a842612e..c721f24a506d 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -108,6 +108,8 @@
from .dots1 import *
from .dpr import *
from .dpt import *
+ from .edgetam import *
+ from .edgetam_video import *
from .efficientloftr import *
from .efficientnet import *
from .electra import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index c40b5a37b02a..f6a12e7cef98 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -127,6 +127,9 @@
("dots1", "Dots1Config"),
("dpr", "DPRConfig"),
("dpt", "DPTConfig"),
+ ("edgetam", "EdgeTamConfig"),
+ ("edgetam_video", "EdgeTamVideoConfig"),
+ ("edgetam_vision_model", "EdgeTamVisionConfig"),
("efficientformer", "EfficientFormerConfig"),
("efficientloftr", "EfficientLoFTRConfig"),
("efficientnet", "EfficientNetConfig"),
@@ -563,6 +566,9 @@
("dots1", "dots1"),
("dpr", "DPR"),
("dpt", "DPT"),
+ ("edgetam", "EdgeTAM"),
+ ("edgetam_video", "EdgeTamVideo"),
+ ("edgetam_vision_model", "EdgeTamVisionModel"),
("efficientformer", "EfficientFormer"),
("efficientloftr", "EfficientLoFTR"),
("efficientnet", "EfficientNet"),
@@ -983,6 +989,7 @@
("qwen3_vl_moe_text", "qwen3_vl_moe"),
("sam_vision_model", "sam"),
("sam2_vision_model", "sam2"),
+ ("edgetam_vision_model", "edgetam"),
("sam2_hiera_det_model", "sam2"),
("sam_hq_vision_model", "sam_hq"),
("llama4_text", "llama4"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index aa16ac3555eb..a272735af207 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -91,6 +91,7 @@
("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")),
("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
+ ("edgetam", (None, "Sam2ImageProcessorFast")),
("efficientformer", ("EfficientFormerImageProcessor", None)),
("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 297d4890d131..298834bebe93 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -131,6 +131,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("dots1", "Dots1Model"),
("dpr", "DPRQuestionEncoder"),
("dpt", "DPTModel"),
+ ("edgetam", "EdgeTamModel"),
+ ("edgetam_video", "EdgeTamVideoModel"),
+ ("edgetam_vision_model", "EdgeTamVisionModel"),
("efficientformer", "EfficientFormerModel"),
("efficientloftr", "EfficientLoFTRModel"),
("efficientnet", "EfficientNetModel"),
@@ -1709,6 +1712,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
+ ("edgetam", "EdgeTamModel"),
+ ("edgetam_video", "EdgeTamModel"),
("sam", "SamModel"),
("sam2", "Sam2Model"),
("sam2_video", "Sam2Model"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 2b1ca09bb8df..11862a5896b9 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -66,6 +66,7 @@
("deepseek_vl", "DeepseekVLProcessor"),
("deepseek_vl_hybrid", "DeepseekVLHybridProcessor"),
("dia", "DiaProcessor"),
+ ("edgetam", "Sam2Processor"),
("emu3", "Emu3Processor"),
("evolla", "EvollaProcessor"),
("flava", "FlavaProcessor"),
diff --git a/src/transformers/models/edgetam/__init__.py b/src/transformers/models/edgetam/__init__.py
new file mode 100644
index 000000000000..d9c1a55fc5bc
--- /dev/null
+++ b/src/transformers/models/edgetam/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_edgetam import *
+ from .modeling_edgetam import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/edgetam/configuration_edgetam.py b/src/transformers/models/edgetam/configuration_edgetam.py
new file mode 100644
index 000000000000..07ccee36e932
--- /dev/null
+++ b/src/transformers/models/edgetam/configuration_edgetam.py
@@ -0,0 +1,332 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/edgetam/modular_edgetam.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_edgetam.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+class EdgeTamVisionConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`EdgeTamVisionModel`]. It is used to instantiate an EdgeTAM
+    vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the EdgeTAM
+    [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*):
+ Configuration for the vision backbone. This is used to instantiate the backbone using
+ `AutoModel.from_config`.
+ backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
+ The list of channel dimensions for the backbone.
+ backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`):
+ The spatial sizes of the feature maps from the backbone.
+ fpn_hidden_size (`int`, *optional*, defaults to 256):
+ The hidden dimension of the FPN.
+ fpn_kernel_size (`int`, *optional*, defaults to 1):
+ The kernel size for the convolutions in the neck.
+ fpn_stride (`int`, *optional*, defaults to 1):
+ The stride for the convolutions in the neck.
+ fpn_padding (`int`, *optional*, defaults to 0):
+ The padding for the convolutions in the neck.
+ fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`):
+ The levels for the top-down FPN connections.
+ num_feature_levels (`int`, *optional*, defaults to 3):
+ The number of feature levels from the FPN to use.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the neck.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon for the layer normalization.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
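+    Example:
+
+    A minimal sketch; calling the constructor with no arguments fetches the default `timm/repvit_m1.dist_in1k`
+    backbone configuration through `AutoConfig.from_pretrained`, as done in `__init__` below.
+
+    ```python
+    >>> from transformers import EdgeTamVisionConfig, EdgeTamVisionModel
+
+    >>> # Initializing an EdgeTamVisionConfig with the default RepViT (TimmWrapper) backbone
+    >>> configuration = EdgeTamVisionConfig()
+
+    >>> # Initializing an EdgeTamVisionModel (with random weights) from the configuration
+    >>> model = EdgeTamVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+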
+ """
+
+ base_config_key = "vision_config"
+ model_type = "edgetam_vision_model"
+ sub_configs = {
+ "backbone_config": AutoConfig,
+ }
+
+ def __init__(
+ self,
+ backbone_config=None,
+ backbone_channel_list=None,
+ backbone_feature_sizes=None,
+ fpn_hidden_size=256,
+ fpn_kernel_size=1,
+ fpn_stride=1,
+ fpn_padding=0,
+ fpn_top_down_levels=None,
+ num_feature_levels=3,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ backbone_channel_list = [384, 192, 96, 48] if backbone_channel_list is None else backbone_channel_list
+ backbone_feature_sizes = (
+ [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes
+ )
+ fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels
+
+ if isinstance(backbone_config, dict):
+ backbone_config["model_type"] = backbone_config.get("model_type", "timm_wrapper")
+ backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config)
+ elif isinstance(backbone_config, AutoConfig):
+ backbone_config = backbone_config
+ elif backbone_config is None:
+ backbone_config = AutoConfig.from_pretrained(
+ "timm/repvit_m1.dist_in1k",
+ model_args={"in_chans": 3, "features_only": True, "out_indices": [0, 1, 2, 3]},
+ )
+
+ self.backbone_config = backbone_config
+
+ # Neck
+ self.backbone_channel_list = backbone_channel_list
+ self.backbone_feature_sizes = backbone_feature_sizes
+ self.fpn_hidden_size = fpn_hidden_size
+ self.fpn_kernel_size = fpn_kernel_size
+ self.fpn_stride = fpn_stride
+ self.fpn_padding = fpn_padding
+ self.fpn_top_down_levels = fpn_top_down_levels
+ self.num_feature_levels = num_feature_levels
+
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+
+
+class EdgeTamPromptEncoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`EdgeTamPromptEncoder`]. The [`EdgeTamPromptEncoder`]
+ module is used to encode the input 2D points and bounding boxes.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states.
+ image_size (`int`, *optional*, defaults to 1024):
+ The expected output resolution of the image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ mask_input_channels (`int`, *optional*, defaults to 16):
+ The number of channels to be fed to the `MaskDecoder` module.
+ num_point_embeddings (`int`, *optional*, defaults to 4):
+ The number of point embeddings to be used.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the encoder and pooler.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ scale (`float`, *optional*, defaults to 1):
+ The scale factor for the prompt encoder.
+ """
+
+ base_config_key = "prompt_encoder_config"
+
+ def __init__(
+ self,
+ hidden_size=256,
+ image_size=1024,
+ patch_size=16,
+ mask_input_channels=16,
+ num_point_embeddings=4,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ scale=1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.mask_input_channels = mask_input_channels
+ self.num_point_embeddings = num_point_embeddings
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.scale = scale
+
+
+class EdgeTamMaskDecoderConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`EdgeTamMaskDecoder`]. It is used to instantiate an EdgeTAM
+    mask decoder according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function in the EdgeTAM mask decoder.
+ mlp_dim (`int`, *optional*, defaults to 2048):
+ The dimension of the MLP in the two-way transformer.
+ num_hidden_layers (`int`, *optional*, defaults to 2):
+ The number of hidden layers in the two-way transformer.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of attention heads in the two-way transformer.
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
+ The downsample rate for the attention layers.
+ num_multimask_outputs (`int`, *optional*, defaults to 3):
+ The number of multimask outputs.
+ iou_head_depth (`int`, *optional*, defaults to 3):
+ The depth of the IoU head.
+ iou_head_hidden_dim (`int`, *optional*, defaults to 256):
+ The hidden dimension of the IoU head.
+ dynamic_multimask_via_stability (`bool`, *optional*, defaults to `True`):
+ Whether to use dynamic multimask via stability.
+ dynamic_multimask_stability_delta (`float`, *optional*, defaults to 0.05):
+ The stability delta for the dynamic multimask.
+ dynamic_multimask_stability_thresh (`float`, *optional*, defaults to 0.98):
+ The stability threshold for the dynamic multimask.
+
+ """
+
+ base_config_key = "mask_decoder_config"
+
+ def __init__(
+ self,
+ hidden_size=256,
+ hidden_act="gelu",
+ mlp_dim=2048,
+ num_hidden_layers=2,
+ num_attention_heads=8,
+ attention_downsample_rate=2,
+ num_multimask_outputs=3,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ dynamic_multimask_via_stability=True,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_multimask_outputs = num_multimask_outputs
+ self.hidden_act = hidden_act
+ self.iou_head_depth = iou_head_depth
+ self.iou_head_hidden_dim = iou_head_hidden_dim
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+ # TwoWayTransformer configuration
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.mlp_dim = mlp_dim
+ self.attention_downsample_rate = attention_downsample_rate
+
+
+class EdgeTamConfig(PretrainedConfig):
+ r"""
+    [`EdgeTamConfig`] is the configuration class to store the configuration of a [`EdgeTamModel`]. It is used to instantiate an
+    EdgeTAM model according to the specified arguments, defining the vision encoder, prompt encoder, and mask decoder
+    configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the EdgeTAM
+    [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (Union[`dict`, `EdgeTamVisionConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`EdgeTamVisionConfig`].
+ prompt_encoder_config (Union[`dict`, `EdgeTamPromptEncoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`EdgeTamPromptEncoderConfig`].
+ mask_decoder_config (Union[`dict`, `EdgeTamMaskDecoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`EdgeTamMaskDecoderConfig`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ Standard deviation for parameter initialization.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... EdgeTamVisionConfig,
+ ... EdgeTamPromptEncoderConfig,
+ ... EdgeTamMaskDecoderConfig,
+ ... EdgeTamModel,
+ ... )
+
+    >>> # Initializing an EdgeTamConfig with `"facebook/EdgeTAM"` style configuration
+    >>> configuration = EdgeTamConfig()
+
+    >>> # Initializing an EdgeTamModel (with random weights) from the `"facebook/EdgeTAM"` style configuration
+ >>> model = EdgeTamModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+    >>> # We can also initialize an EdgeTamConfig from an EdgeTamVisionConfig, EdgeTamPromptEncoderConfig, and EdgeTamMaskDecoderConfig
+
+    >>> # Initializing EdgeTAM vision encoder, prompt encoder, and mask decoder configurations
+ >>> vision_config = EdgeTamVisionConfig()
+ >>> prompt_encoder_config = EdgeTamPromptEncoderConfig()
+ >>> mask_decoder_config = EdgeTamMaskDecoderConfig()
+
+ >>> config = EdgeTamConfig(vision_config, prompt_encoder_config, mask_decoder_config)
+ ```"""
+
+ model_type = "edgetam"
+ sub_configs = {
+ "vision_config": AutoConfig,
+ "prompt_encoder_config": EdgeTamPromptEncoderConfig,
+ "mask_decoder_config": EdgeTamMaskDecoderConfig,
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ prompt_encoder_config=None,
+ mask_decoder_config=None,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ vision_config = vision_config if vision_config is not None else {}
+ prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+ mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "edgetam_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ if isinstance(prompt_encoder_config, EdgeTamPromptEncoderConfig):
+ prompt_encoder_config = prompt_encoder_config.to_dict()
+ if isinstance(mask_decoder_config, EdgeTamMaskDecoderConfig):
+ mask_decoder_config = mask_decoder_config.to_dict()
+
+ self.vision_config = vision_config
+ self.prompt_encoder_config = EdgeTamPromptEncoderConfig(**prompt_encoder_config)
+ self.mask_decoder_config = EdgeTamMaskDecoderConfig(**mask_decoder_config)
+
+ self.initializer_range = initializer_range
+
+
+__all__ = ["EdgeTamConfig", "EdgeTamVisionConfig", "EdgeTamPromptEncoderConfig", "EdgeTamMaskDecoderConfig"]
diff --git a/src/transformers/models/edgetam/convert_edgetam_to_hf.py b/src/transformers/models/edgetam/convert_edgetam_to_hf.py
new file mode 100644
index 000000000000..382bc1559ec4
--- /dev/null
+++ b/src/transformers/models/edgetam/convert_edgetam_to_hf.py
@@ -0,0 +1,280 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Convert EdgeTAM checkpoints from the original repository.
+
+URL: https://github.com/facebookresearch/EdgeTAM.
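+
+Example usage (a sketch; the output folder path is illustrative, and when `--checkpoint_path` is omitted the
+`edgetam.pt` checkpoint is downloaded from the `facebook/EdgeTAM` Hub repository, as done in `__main__` below):
+
+    python src/transformers/models/edgetam/convert_edgetam_to_hf.py \
+        --model_name EdgeTAM \
+        --pytorch_dump_folder_path ./EdgeTAM-hf \
+        --run_sanity_check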
+"""
+
+import argparse
+import re
+
+import numpy as np
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from transformers import (
+ EdgeTamConfig,
+ EdgeTamMaskDecoderConfig,
+ EdgeTamModel,
+ EdgeTamPromptEncoderConfig,
+ EdgeTamVisionConfig,
+ Sam2ImageProcessorFast,
+ Sam2Processor,
+ TimmWrapperConfig,
+)
+
+
+def get_config(model_name):
+ backbone_config = TimmWrapperConfig.from_pretrained(
+ "timm/repvit_m1.dist_in1k",
+ model_args={"in_chans": 3, "features_only": True, "out_indices": (0, 1, 2, 3)},
+ )
+ vision_config = EdgeTamVisionConfig(backbone_config=backbone_config)
+
+ prompt_encoder_config = EdgeTamPromptEncoderConfig()
+ mask_decoder_config = EdgeTamMaskDecoderConfig()
+ enable_temporal_pos_encoding_for_object_pointers = False
+ project_temporal_pos_encoding_in_object_pointers = False
+ enable_occlusion_spatial_embedding = False
+
+ config = EdgeTamConfig(
+ vision_config=vision_config,
+ prompt_encoder_config=prompt_encoder_config,
+ mask_decoder_config=mask_decoder_config,
+ enable_temporal_pos_encoding_for_object_pointers=enable_temporal_pos_encoding_for_object_pointers,
+ project_temporal_pos_encoding_in_object_pointers=project_temporal_pos_encoding_in_object_pointers,
+ enable_occlusion_spatial_embedding=enable_occlusion_spatial_embedding,
+ )
+
+ return config
+
+
+KEYS_TO_MODIFY_MAPPING = {
+ "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
+ "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
+ "iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
+ "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
+ "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
+ "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
+ "mask_downscaling.0": "mask_embed.conv1",
+ "mask_downscaling.1": "mask_embed.layer_norm1",
+ "mask_downscaling.3": "mask_embed.conv2",
+ "mask_downscaling.4": "mask_embed.layer_norm2",
+ "mask_downscaling.6": "mask_embed.conv3",
+ "dwconv": "depthwise_conv",
+ "pwconv": "pointwise_conv",
+ "fuser": "memory_fuser",
+ "point_embeddings": "point_embed",
+ "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
+ "obj_ptr_tpos_proj": "temporal_positional_encoding_projection_layer",
+ "no_obj_embed_spatial": "occlusion_spatial_embedding_parameter",
+ "sam_prompt_encoder": "prompt_encoder",
+ "sam_mask_decoder": "mask_decoder",
+ "maskmem_tpos_enc": "memory_temporal_positional_encoding",
+ "gamma": "scale",
+ "image_encoder.neck": "vision_encoder.neck",
+ "image_encoder": "vision_encoder.backbone",
+ "neck.0": "neck.conv1",
+ "neck.1": "neck.layer_norm1",
+ "neck.2": "neck.conv2",
+ "neck.3": "neck.layer_norm2",
+ "pix_feat_proj": "feature_projection",
+ "patch_embed.proj": "patch_embed.projection",
+ "no_mem_embed": "no_memory_embedding",
+ "no_mem_pos_enc": "no_memory_positional_encoding",
+ "obj_ptr": "object_pointer",
+ ".norm": ".layer_norm",
+ "trunk.": "",
+ "out_proj": "o_proj",
+ "body.": "timm_model.",
+ "ff.0": "feed_forward.layer_norm",
+ "ff.1": "feed_forward.linear1",
+ "ff.3": "feed_forward.linear2",
+}
+
+
+def replace_keys(state_dict):
+ model_state_dict = {}
+ output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*"
+ output_mask_decoder_mlps_pattern = r"mask_decoder.transformer.layers.(\d+).mlp.layers.(\d+).*"
+ output_mask_decoder_score_head_pattern = r"mask_decoder.pred_obj_score_head.layers.(\d+).*"
+ output_vision_encoder_mlps_pattern = r"vision_encoder.backbone.blocks.(\d+).mlp.layers.(\d+).*"
+ output_vision_encoder_neck_pattern = r"vision_encoder.neck.convs.(\d+).conv"
+ output_memory_encoder_projection_pattern = r"memory_encoder.o_proj.*"
+ output_object_pointer_proj_pattern = r"object_pointer_proj.layers.(\d+).*"
+ for key, value in state_dict.items():
+ for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
+ if key_to_modify in key:
+ key = key.replace(key_to_modify, new_key)
+
+ # vision_encoder.blocks.0.mlp.layers.1.weight -> vision_encoder.blocks.0.mlp.proj_out.weight
+ if re.match(output_vision_encoder_mlps_pattern, key):
+ layer_nb = int(re.match(output_vision_encoder_mlps_pattern, key).group(2))
+ if layer_nb == 0:
+ key = key.replace("layers.0", "proj_in")
+ elif layer_nb == 1:
+ key = key.replace("layers.1", "proj_out")
+
+ # mask_decoder.transformer.layers.0.mlp.layers.1.weight -> mask_decoder.transformer.layers.1.mlp.proj_out.weight
+ if re.match(output_mask_decoder_mlps_pattern, key):
+ layer_nb = int(re.match(output_mask_decoder_mlps_pattern, key).group(2))
+ if layer_nb == 0:
+ key = key.replace("mlp.layers.0", "mlp.proj_in")
+ elif layer_nb == 1:
+ key = key.replace("mlp.layers.1", "mlp.proj_out")
+
+ # mask_decoder.pred_obj_score_head.layers.1.weight -> mask_decoder.pred_obj_score_head.proj_in.weight
+ if re.match(output_mask_decoder_score_head_pattern, key):
+ layer_nb = int(re.match(output_mask_decoder_score_head_pattern, key).group(1))
+ if layer_nb == 0:
+ key = key.replace("layers.0", "proj_in")
+ elif layer_nb == 1:
+ key = key.replace("layers.1", "layers.0")
+ elif layer_nb == 2:
+ key = key.replace("layers.2", "proj_out")
+
+ if re.match(output_hypernetworks_mlps_pattern, key):
+ layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
+ if layer_nb == 0:
+ key = key.replace("layers.0", "proj_in")
+ elif layer_nb == 1:
+ key = key.replace("layers.1", "layers.0")
+ elif layer_nb == 2:
+ key = key.replace("layers.2", "proj_out")
+
+ # vision_encoder.neck.convs.1.conv.bias -> vision_encoder.neck.convs.1.bias
+ if re.match(output_vision_encoder_neck_pattern, key):
+ key = key.replace(".conv.", ".")
+
+ # memory_encoder.o_proj.weight -> memory_encoder.projection.weight
+ if re.match(output_memory_encoder_projection_pattern, key):
+ key = key.replace(".o_proj.", ".projection.")
+
+ if re.match(output_object_pointer_proj_pattern, key):
+ layer_nb = int(re.match(output_object_pointer_proj_pattern, key).group(1))
+ if layer_nb == 0:
+ key = key.replace("layers.0", "proj_in")
+ elif layer_nb == 1:
+ key = key.replace("layers.1", "layers.0")
+ elif layer_nb == 2:
+ key = key.replace("layers.2", "proj_out")
+
+ key = key.replace("layers.2", "proj_out")
+
+ model_state_dict[key] = value
+
+ model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[
+ "prompt_encoder.shared_embedding.positional_embedding"
+ ]
+ model_state_dict["prompt_encoder.point_embed.weight"] = torch.cat(
+ [model_state_dict.pop(f"prompt_encoder.point_embed.{i}.weight") for i in range(4)],
+ dim=0,
+ )
+
+ return model_state_dict
+
+
+def convert_edgetam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub, run_sanity_check):
+ config = get_config(model_name)
+
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+ state_dict = replace_keys(state_dict)
+
+ image_processor = Sam2ImageProcessorFast()
+ processor = Sam2Processor(image_processor=image_processor)
+ hf_model = EdgeTamModel(config)
+ hf_model.eval()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
+ hf_model = hf_model.to(device)
+ for pattern in EdgeTamModel._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
+ if missing_keys or unexpected_keys:
+ print("Missing keys:", missing_keys)
+ print("Unexpected keys:", unexpected_keys)
+ raise ValueError("Missing or unexpected keys in the state dict")
+
+ if run_sanity_check:
+ img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+
+ input_points = [[[[1000, 600]]]]
+ input_labels = [[[1]]]
+
+ inputs = processor(
+ images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
+ ).to(device)
+
+ with torch.no_grad():
+ output = hf_model(**inputs)
+ scores = output.iou_scores.squeeze()
+
+        assert torch.allclose(scores, torch.tensor([0.0356, 0.2141, 0.9707], device=scores.device), atol=1e-3)
+
+ if pytorch_dump_folder is not None:
+ processor.save_pretrained(pytorch_dump_folder)
+ hf_model.save_pretrained(pytorch_dump_folder)
+
+ if push_to_hub:
+ repo_id = f"yonigozlan/{pytorch_dump_folder.split('/')[-1]}"
+ processor.push_to_hub(repo_id)
+ hf_model.push_to_hub(repo_id)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ choices = ["EdgeTAM"]
+ parser.add_argument(
+ "--model_name",
+ default="EdgeTAM",
+ choices=choices,
+ type=str,
+ help="Name of the original model to convert",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ required=False,
+ help="Path to the original checkpoint",
+ )
+ parser.add_argument("--pytorch_dump_folder_path", default="", type=str, help="Path to the output PyTorch model.")
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether to push the model and processor to the hub after converting",
+ )
+ parser.add_argument(
+ "--run_sanity_check",
+ action="store_true",
+ help="Whether to run the sanity check after converting",
+ )
+
+ args = parser.parse_args()
+
+ hf_model_name = args.model_name.replace("_", "-")
+ checkpoint_path = (
+ hf_hub_download(f"facebook/{hf_model_name}", f"{args.model_name.lower()}.pt")
+ if args.checkpoint_path is None
+ else args.checkpoint_path
+ )
+
+ convert_edgetam_checkpoint(
+ args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub, args.run_sanity_check
+ )
diff --git a/src/transformers/models/edgetam/modeling_edgetam.py b/src/transformers/models/edgetam/modeling_edgetam.py
new file mode 100644
index 000000000000..d7e3ee6009cf
--- /dev/null
+++ b/src/transformers/models/edgetam/modeling_edgetam.py
@@ -0,0 +1,1252 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/edgetam/modular_edgetam.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_edgetam.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import ModelOutput, auto_docstring
+from ..auto import AutoModel
+from .configuration_edgetam import (
+ EdgeTamConfig,
+ EdgeTamMaskDecoderConfig,
+ EdgeTamPromptEncoderConfig,
+ EdgeTamVisionConfig,
+)
+
+
+# fix this in modular
+if True:
+ from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel
+
+
+class EdgeTamLayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
+class EdgeTamVisionEncoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ fpn_hidden_states (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
+ fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
+ model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ fpn_hidden_states: Optional[torch.FloatTensor] = None
+ fpn_position_encoding: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class EdgeTamAttention(nn.Module):
+ """
+    EdgeTAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+ values.
+ """
+
+ def __init__(self, config, downsample_rate=None):
+ super().__init__()
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.internal_dim = config.hidden_size // downsample_rate
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = self.internal_dim // config.num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_similarity: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_similarity,
+ dropout=0.0,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class EdgeTamTwoWayAttentionBlock(nn.Module):
+ def __init__(self, config: EdgeTamMaskDecoderConfig, skip_first_layer_pe: bool = False):
+ """
+ A transformer block with four layers:
+ (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
+ sparse inputs (4) cross attention of dense inputs -> sparse inputs
+
+ Arguments:
+ config (`EdgeTamMaskDecoderConfig`):
+ The configuration file used to instantiate the block
+            attention_downsample_rate (`int`, *optional*, defaults to 2):
+                The downsample ratio of the block used to reduce the inner dim of the attention.
+            skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
+ Whether or not to skip the addition of the query_point_embedding on the first layer.
+ """
+ super().__init__()
+ self.self_attn = EdgeTamAttention(config, downsample_rate=1)
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size)
+
+ self.cross_attn_token_to_image = EdgeTamAttention(config)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size)
+
+ self.mlp = EdgeTamFeedForward(
+ config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
+ )
+ self.layer_norm3 = nn.LayerNorm(config.hidden_size)
+
+ self.layer_norm4 = nn.LayerNorm(config.hidden_size)
+ self.cross_attn_image_to_token = EdgeTamAttention(config)
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ query_point_embedding: Tensor,
+ key_point_embedding: Tensor,
+ attention_similarity: Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries, _ = self.self_attn(query=queries, key=queries, value=queries)
+ else:
+ query = queries + query_point_embedding
+ attn_out, _ = self.self_attn(query=query, key=query, value=queries)
+ queries = queries + attn_out
+ queries = self.layer_norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_token_to_image(
+ query=query, key=key, value=keys, attention_similarity=attention_similarity
+ )
+ queries = queries + attn_out
+
+ queries = self.layer_norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.layer_norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+ keys = keys + attn_out
+
+ keys = self.layer_norm4(keys)
+ return queries, keys, attn_out
+
+
+class EdgeTamFeedForward(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ activation: str = "relu",
+ sigmoid_output: bool = False,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.activation = ACT2FN[activation]
+ self.proj_in = nn.Linear(input_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ for layer in self.layers:
+ hidden_states = self.activation(layer(hidden_states))
+
+ hidden_states = self.proj_out(hidden_states)
+ if self.sigmoid_output:
+ hidden_states = F.sigmoid(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class EdgeTamPreTrainedModel(PreTrainedModel):
+ config_class = EdgeTamConfig
+ base_model_prefix = "edgetam"
+ main_input_name = "pixel_values"
+ _supports_sdpa = True
+ _supports_flash_attn_2 = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ if isinstance(module, EdgeTamModel):
+ if module.no_memory_embedding is not None:
+ module.no_memory_embedding.data.zero_()
+
+
+# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
+class EdgeTamSinePositionEmbedding(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
+ need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
+ ):
+ super().__init__()
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = 2 * math.pi if scale is None else scale
+
+ @compile_compatible_method_lru_cache(maxsize=1)
+ def forward(
+ self,
+ shape: torch.Size,
+ device: Union[torch.device, str],
+ dtype: torch.dtype,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ if mask is None:
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+ not_mask = (~mask).to(dtype)
+ y_embed = not_mask.cumsum(1)
+ x_embed = not_mask.cumsum(2)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class EdgeTamVisionNeck(nn.Module):
+ def __init__(self, config: EdgeTamVisionConfig):
+ super().__init__()
+ self.config = config
+
+ self.position_encoding = EdgeTamSinePositionEmbedding(
+ num_pos_feats=config.fpn_hidden_size // 2, normalize=True
+ )
+ self.convs = nn.ModuleList()
+ for in_channels in config.backbone_channel_list:
+ self.convs.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=config.fpn_hidden_size,
+ kernel_size=config.fpn_kernel_size,
+ stride=config.fpn_stride,
+ padding=config.fpn_padding,
+ ),
+ )
+ self.fpn_top_down_levels = config.fpn_top_down_levels
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
+ fpn_hidden_states = ()
+ fpn_position_encoding = ()
+
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
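+            # `convs` follows the `backbone_channel_list` ordering, which is reversed with respect to the
+            # order of the incoming backbone feature maps, hence the `convs[n - i]` pairing below.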
+ lateral_features = hidden_states[i].permute(0, 3, 1, 2)
+ lateral_features = self.convs[n - i](lateral_features)
+ if i not in self.fpn_top_down_levels or i == n:
+ prev_features = lateral_features
+ else:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode="nearest",
+ align_corners=None,
+ antialias=False,
+ ).to(lateral_features.dtype)
+ prev_features = lateral_features + top_down_features
+
+ prev_position_encoding = self.position_encoding(
+ prev_features.shape, prev_features.device, prev_features.dtype
+ ).to(prev_features.dtype)
+
+ fpn_hidden_states += (prev_features,)
+ fpn_position_encoding += (prev_position_encoding,)
+
+ return fpn_hidden_states, fpn_position_encoding
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from EdgeTAM without any head or projection on top.
+ """
+)
+class EdgeTamVisionModel(EdgeTamPreTrainedModel):
+ config_class = EdgeTamVisionConfig
+ main_input_name = "pixel_values"
+ _can_record_outputs = {"hidden_states": TimmWrapperModel, "attentions": TimmWrapperModel}
+
+ def __init__(self, config: EdgeTamVisionConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.backbone = AutoModel.from_config(config.backbone_config)
+
+ self.neck = EdgeTamVisionNeck(config)
+ self.num_feature_levels = config.num_feature_levels
+
+ self.post_init()
+
+ @check_model_inputs
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, EdgeTamVisionEncoderOutput]:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Forward through backbone
+ backbone_output = self.backbone(pixel_values)
+ intermediate_hidden_states = backbone_output.last_hidden_state
+ intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states]
+
+ fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+ # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
+ fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
+ fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+
+ return EdgeTamVisionEncoderOutput(
+ last_hidden_state=intermediate_hidden_states[-1],
+ fpn_hidden_states=fpn_hidden_states,
+ fpn_position_encoding=fpn_position_encoding,
+ )
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the EdgeTam model's output.")
+class EdgeTamImageSegmentationOutput(ModelOutput):
+ r"""
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
+ The Intersection over Union (IoU) scores of the predicted masks.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
+ The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
+ by the processor to be brought to the original image size.
+ object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
+ Logits for the object score, indicating if an object is present.
+ image_embeddings (`tuple(torch.FloatTensor)`):
+ The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
+ tensor has shape `(batch_size, channels, height, width)`.
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
+ Hidden-states of the vision model at the output of each stage.
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the vision model.
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the mask decoder.
+ """
+
+ iou_scores: Optional[torch.FloatTensor] = None
+ pred_masks: Optional[torch.FloatTensor] = None
+ object_score_logits: Optional[torch.FloatTensor] = None
+ image_embeddings: tuple[torch.FloatTensor, ...] = None
+ vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class EdgeTamPositionalEmbedding(nn.Module):
+ def __init__(self, config: EdgeTamPromptEncoderConfig):
+ super().__init__()
+ self.scale = config.scale
+ positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+ self.register_buffer("positional_embedding", positional_embedding)
+
+ def forward(self, input_coords, input_shape=None):
+ """Positionally encode points that are normalized to [0,1]."""
+ coordinates = input_coords.clone()
+
+ if input_shape is not None:
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+            coordinates = coordinates.to(torch.float32)
+
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coordinates = 2 * coordinates - 1
+ coordinates = coordinates.to(self.positional_embedding.dtype)
+ coordinates = coordinates @ self.positional_embedding
+ coordinates = 2 * np.pi * coordinates
+ # outputs d_1 x ... x d_n x channel shape
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
+class EdgeTamMaskEmbedding(nn.Module):
+ def __init__(self, config: EdgeTamPromptEncoderConfig):
+ super().__init__()
+ self.mask_input_channels = config.mask_input_channels // 4
+ self.activation = ACT2FN[config.hidden_act]
+ self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
+ self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
+ self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
+ self.layer_norm1 = EdgeTamLayerNorm(
+ self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+ self.layer_norm2 = EdgeTamLayerNorm(
+ self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+
+ def forward(self, masks):
+ hidden_states = self.conv1(masks)
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ dense_embeddings = self.conv3(hidden_states)
+ return dense_embeddings
+
+
+class EdgeTamPromptEncoder(nn.Module):
+ def __init__(self, config: EdgeTamPromptEncoderConfig):
+ super().__init__()
+ self.shared_embedding = EdgeTamPositionalEmbedding(config)
+ self.mask_embed = EdgeTamMaskEmbedding(config)
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
+
+ self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
+ self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
+ self.input_image_size = config.image_size
+
+ self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
+
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
+ labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
+ input_shape = (self.input_image_size, self.input_image_size)
+ point_embedding = self.shared_embedding(points, input_shape)
+
+ # torch.where and expanding the labels tensor is required by the ONNX export
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
+
+ # This is required for the ONNX export. The dtype, device need to be explicitly
+ # specified as otherwise torch.onnx.export interprets as double
+ point_embedding = torch.where(
+ labels[..., None] != -10,
+ point_embedding,
+ torch.zeros_like(point_embedding),
+ )
+
+ # Add point embeddings for labels >= 0
+ point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
+
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+        boxes = boxes + 0.5  # Shift to center of pixel (avoid modifying the caller's tensor in place)
+ coords = boxes.view(*boxes.shape[:2], 2, 2)
+ # add padding point for consistency with the original implementation
+ coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
+ corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
+ corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
+ corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
+ corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
+ return corner_embedding
+
+ def forward(
+ self,
+ input_points: Optional[tuple[torch.Tensor, torch.Tensor]],
+ input_labels: Optional[torch.Tensor],
+ input_boxes: Optional[torch.Tensor],
+ input_masks: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense embeddings.
+
+ Args:
+ points (`torch.Tensor`, *optional*):
+ point coordinates and labels to embed.
+ boxes (`torch.Tensor`, *optional*):
+ boxes to embed
+ masks (`torch.Tensor`, *optional*):
+ masks to embed
+ """
+ sparse_embeddings = None
+ batch_size = 1
+ if input_points is not None:
+ batch_size = input_points.shape[0]
+ if input_labels is None:
+ raise ValueError("If points are provided, labels must also be provided.")
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
+ sparse_embeddings = point_embeddings
+ if input_boxes is not None:
+ batch_size = input_boxes.shape[0]
+ box_embeddings = self._embed_boxes(input_boxes)
+ if sparse_embeddings is None:
+ sparse_embeddings = box_embeddings
+ else:
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
+ if input_masks is not None:
+ dense_embeddings = self.mask_embed(input_masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
+
+
+class EdgeTamTwoWayTransformer(nn.Module):
+ def __init__(self, config: EdgeTamMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_hidden_layers = config.num_hidden_layers
+ self.layers = nn.ModuleList()
+
+ for i in range(self.num_hidden_layers):
+ self.layers.append(EdgeTamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
+
+ self.final_attn_token_to_image = EdgeTamAttention(config)
+ self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ point_embeddings: Tensor,
+ image_embeddings: Tensor,
+ image_positional_embeddings: Tensor,
+ attention_similarity: Tensor,
+ target_embedding=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ if image_embeddings is None:
+ raise ValueError("You have to specify an image_embedding")
+
+ image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+ image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+
+ # Prepare queries
+ queries = point_embeddings
+ keys = image_embeddings
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ if target_embedding is not None:
+ queries += target_embedding
+
+ queries, keys, _ = layer(
+ queries=queries,
+ keys=keys,
+ query_point_embedding=point_embeddings,
+ key_point_embedding=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ **kwargs,
+ )
+ # Apply the final attention layer from the points to the image
+ query = queries + point_embeddings
+ key = keys + image_positional_embeddings
+
+ attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
+
+ queries = queries + attn_out
+ queries = self.layer_norm_final_attn(queries)
+ return queries, keys
+
+
+class EdgeTamMaskDecoder(nn.Module):
+ def __init__(self, config: EdgeTamMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.num_multimask_outputs = config.num_multimask_outputs
+ self.num_mask_tokens = config.num_multimask_outputs + 1
+
+ self.iou_token = nn.Embedding(1, self.hidden_size)
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
+
+ self.transformer = EdgeTamTwoWayTransformer(config)
+
+ # should we create a new class for this?
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
+ self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
+ self.upscale_layer_norm = EdgeTamLayerNorm(self.hidden_size // 4, data_format="channels_first")
+ self.activation = nn.GELU()
+
+ mlps_list = []
+ for _ in range(self.num_mask_tokens):
+ mlps_list += [EdgeTamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
+ self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
+ self.iou_prediction_head = EdgeTamFeedForward(
+ self.hidden_size,
+ config.iou_head_hidden_dim,
+ self.num_mask_tokens,
+ config.iou_head_depth,
+ sigmoid_output=True,
+ )
+
+ self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
+ self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
+
+ self.obj_score_token = nn.Embedding(1, self.hidden_size)
+ self.pred_obj_score_head = EdgeTamFeedForward(self.hidden_size, self.hidden_size, 1, 3)
+
+ self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_positional_embeddings: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ high_resolution_features: list[torch.Tensor],
+ attention_similarity: Optional[torch.Tensor] = None,
+ target_embedding: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Args:
+ image_embeddings (`torch.Tensor`):
+ The embeddings from the image encoder.
+ image_positional_embeddings (`torch.Tensor`):
+ Positional encoding with the shape of image_embeddings.
+ sparse_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the points and boxes.
+ dense_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the mask inputs.
+ multimask_output (`bool`):
+ Whether to return multiple masks or a single mask.
+ high_resolution_features (`list[torch.Tensor]`, *optional*):
+ The high-resolution features from the vision encoder.
+ attention_similarity (`torch.Tensor`, *optional*):
+ The attention similarity tensor.
+ target_embedding (`torch.Tensor`, *optional*):
+ The target embedding.
+ """
+ batch_size, num_channels, height, width = image_embeddings.shape
+ point_batch_size = sparse_prompt_embeddings.shape[1]
+ # Concatenate output tokens
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+ if sparse_prompt_embeddings.shape[0] != 0:
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+ else:
+ tokens = output_tokens
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+ # Expand per-image data in batch direction to be per-mask
+ image_embeddings = image_embeddings + dense_prompt_embeddings
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+ # Run the transformer
+ point_embeddings, image_embeddings = self.transformer(
+ point_embeddings=point_embeddings,
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
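+ # token layout after the two-way transformer: index 0 is the object score token, 1 the IoU token, 2 onwards the mask tokens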
+ iou_token_out = point_embeddings[:, :, 1, :]
+ mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ image_embeddings = image_embeddings.transpose(2, 3).view(
+ batch_size * point_batch_size, num_channels, height, width
+ )
+
+ feat_s0, feat_s1 = high_resolution_features
+ feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
+ feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
+ upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: list[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ current_mlp = self.output_hypernetworks_mlps[i]
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+ hyper_in = torch.stack(hyper_in_list, dim=2)
+
+ _, num_channels, height, width = upscaled_embedding.shape
+ upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
+ masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ mask_slice = slice(0, 1)
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+
+ sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [batch_size, point_batch_size, num_selected_masks, hidden_size]
+
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """
+ Compute stability scores of the mask logits based on the IoU between upper and
+ lower thresholds.
+ """
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
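+ # intersection / union areas of the mask thresholded at +delta and -delta; their ratio is the stability score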
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ When outputting a single mask, if the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, we instead select from
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, :, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, :, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P]
+ best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ best_scores_inds_expanded = best_scores_inds_expanded.expand(
+ -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
+ )
+ best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W]
+ best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1]
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, :, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
+
+
+@auto_docstring(
+ custom_intro="""
+ EdgeTAM model for generating segmentation masks, given an input image and input prompts such as
+ points and labels, boxes, or masks.
+ """
+)
+class EdgeTamModel(EdgeTamPreTrainedModel):
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
+ # needs to be ignored, as it's a buffer and will not be correctly detected as a tied weight
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
+ _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamTwoWayAttentionBlock, index=2)}
+ _keys_to_ignore_on_load_unexpected = [
+ r"^memory_.*",
+ r"^mask_downsample.*",
+ r"spatial_perceiver.*",
+ r"^object_pointer_proj.*",
+ r"^temporal_positional_encoding_projection_layer.*",
+ "no_memory_positional_encoding",
+ "no_object_pointer",
+ "occlusion_spatial_embedding_parameter",
+ ]
+
+ def __init__(self, config: EdgeTamConfig):
+ super().__init__(config)
+ self.shared_image_embedding = EdgeTamPositionalEmbedding(config.prompt_encoder_config)
+ self.vision_encoder = AutoModel.from_config(config.vision_config)
+ self.prompt_encoder = EdgeTamPromptEncoder(config.prompt_encoder_config)
+ # The module using it is not a PreTrainedModel subclass so we need this
+ config.mask_decoder_config._attn_implementation = config._attn_implementation
+ self.mask_decoder = EdgeTamMaskDecoder(config.mask_decoder_config)
+
+ self.num_feature_levels = config.vision_config.num_feature_levels
+ self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
+ # a single token to indicate no memory embedding from previous frames
+ self.hidden_dim = config.vision_config.fpn_hidden_size
+ self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+
+ self.post_init()
+
+ def _tie_weights(self):
+ self.prompt_encoder.shared_embedding.positional_embedding.data = (
+ self.shared_image_embedding.positional_embedding.data
+ )
+
+ def get_image_wide_positional_embeddings(self) -> torch.Tensor:
+ size = self.prompt_encoder.image_embedding_size
+ target_device = self.shared_image_embedding.positional_embedding.device
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
+ grid = torch.ones(size, device=target_device, dtype=target_dtype)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / size[0]
+ x_embed = x_embed / size[1]
+
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0) # 1 x channel x height x width
+
+ @torch.no_grad()
+ def get_image_embeddings(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> list[torch.Tensor]:
+ r"""
+ Returns the image embeddings by passing the pixel values through the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Input pixel values
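+
+ Example (a minimal sketch for precomputing image embeddings; it reuses the checkpoint and image from the `forward` example of this class):
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoModel, AutoProcessor
+
+ >>> model = AutoModel.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
+ >>> processor = AutoProcessor.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
+
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ >>> inputs = processor(images=raw_image, return_tensors="pt")
+
+ >>> image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
+ >>> # the embeddings can then be reused for several prompts via `model(image_embeddings=image_embeddings, ...)`
+ ```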
+ """
+ batch_size = pixel_values.shape[0]
+ feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def get_prompt_embeddings(
+ self,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ ):
+ r"""
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+ Args:
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+ point. The model will output `point_batch_size` times 3 masks in total.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+ processor, or can be fed by the user.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+ processor. Users can also pass the input boxes manually.
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+ Optional input masks for the prompt encoder.
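+
+ Example (a minimal sketch; the checkpoint id mirrors the one used in the `forward` example of [`EdgeTamModel`], and the raw point coordinates are illustrative rather than processor-prepared):
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoModel
+
+ >>> model = AutoModel.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
+
+ >>> # one foreground point (label 1) for a single image and a single mask
+ >>> input_points = torch.tensor([[[[400.0, 650.0]]]]) # (batch_size, point_batch_size, num_points_per_image, 2)
+ >>> input_labels = torch.tensor([[[1]]]) # (batch_size, point_batch_size, num_points_per_image)
+ >>> sparse_embeddings, dense_embeddings = model.get_prompt_embeddings(input_points=input_points, input_labels=input_labels)
+ ```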
+ """
+ prompt_output = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ return prompt_output
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ multimask_output: bool = True,
+ attention_similarity: Optional[torch.FloatTensor] = None,
+ target_embedding: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> EdgeTamImageSegmentationOutput:
+ r"""
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Input 2D spatial points, used by the prompt encoder to encode the prompt. Generally yields much
+ better results. The points can be obtained by passing a list of lists of lists to the processor, which will
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+ second dimension is the point batch size (i.e. how many segmentation masks we want the model to predict
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+ multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
+ coordinates of the point. If a different number of points is passed either for each image, or for each
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+ computation of the embedding will be skipped for these points using the labels.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+ official implementation, there are 3 types of labels
+
+ - `1`: the point is a point that contains the object of interest
+ - `0`: the point is a point that does not contain the object of interest
+ - `-1`: the point corresponds to the background
+
+ We added the label:
+
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+ The padding labels should be automatically done by the processor.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+ Input boxes for the prompt, used by the prompt encoder to encode the prompt. Generally yields much
+ better generated masks. The boxes can be obtained by passing a list of lists of lists to the processor,
+ which will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+ size, the number of boxes per image and the coordinates of the top left and bottom right point of the box,
+ in the order (`x1`, `y1`, `x2`, `y2`):
+
+ - `x1`: the x coordinate of the top left point of the input box
+ - `y1`: the y coordinate of the top left point of the input box
+ - `x2`: the x coordinate of the bottom right point of the input box
+ - `y2`: the y coordinate of the bottom right point of the input box
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+ The model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+ generate a corresponding embedding that will later be fed to the mask decoder. These masks need to be
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+ image_embeddings (`list[torch.FloatTensor]`, *optional*):
+ Multi-scale image embeddings, used by the mask decoder to generate masks and IoU scores. For more
+ memory-efficient computation, users can first retrieve the image embeddings using the
+ `get_image_embeddings` method, and then feed them to the `forward` method instead of feeding the
+ `pixel_values`.
+ multimask_output (`bool`, *optional*):
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+ bounding box if relevant). However, it is possible to output a single mask, corresponding to the
+ "best" mask, by specifying `multimask_output=False`.
+ attention_similarity (`torch.FloatTensor`, *optional*):
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+ model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ target_embedding (`torch.FloatTensor`, *optional*):
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+ the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoModel, AutoProcessor
+
+ >>> model = AutoModel.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
+ >>> processor = AutoProcessor.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
+
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ >>> input_points = [[[400, 650]]] # 2D location of a window on the car
+ >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
+
+ >>> # Get segmentation mask
+ >>> outputs = model(**inputs)
+
+ >>> # Postprocess masks
+ >>> masks = processor.post_process_masks(
+ ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
+ ... )
+ ```
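+
+ To get a single "best" mask per prompt instead of three, pass `multimask_output=False` (continuing the example above):
+
+ ```python
+ >>> outputs = model(**inputs, multimask_output=False)
+ >>> best_mask = outputs.pred_masks[:, :, 0] # a single low-resolution mask per prompt
+ ```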
+ """
+ if not ((pixel_values is None) ^ (image_embeddings is None)):
+ raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
+ if input_points is not None and input_boxes is not None:
+ if input_points.shape[1] != input_boxes.shape[1]:
+ raise ValueError(
+ f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
+ )
+
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
+ pixel_values,
+ **kwargs,
+ )
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ if input_points is not None and input_labels is None:
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+ if input_points is None and input_boxes is None:
+ # If no points are provided, pad with an empty point (with label -1)
+ input_points = torch.zeros(
+ batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
+ )
+ input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
+
+ if input_masks is not None:
+ # If input_masks is provided, downsize it into a low-res mask input if needed
+ # and feed it as a dense mask prompt into the prompt encoder
+ if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
+ input_masks = F.interpolate(
+ input_masks.float(),
+ size=self.prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(input_masks.dtype)
+
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
+ image_embeddings=image_embeddings[-1],
+ image_positional_embeddings=image_positional_embeddings,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ high_resolution_features=image_embeddings[:-1],
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+
+ return EdgeTamImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_multimasks,
+ object_score_logits=object_score_logits,
+ image_embeddings=image_embeddings,
+ vision_hidden_states=vision_hidden_states,
+ vision_attentions=vision_attentions,
+ )
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[
+ list[torch.Tensor],
+ list[torch.Tensor],
+ Optional[tuple[torch.FloatTensor, ...]],
+ Optional[tuple[torch.FloatTensor, ...]],
+ ]:
+ r"""
+ Extract and preprocess image features using the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Input pixel values of shape `(batch_size, num_channels, height, width)`.
+
+ Returns:
+ `tuple`: A tuple containing:
+ - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels.
+ - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level.
+ - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder.
+ - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder.
+ """
+ vision_outputs: EdgeTamVisionEncoderOutput = self.vision_encoder(
+ pixel_values,
+ **kwargs,
+ )
+
+ feature_maps = vision_outputs.fpn_hidden_states
+ feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+
+ # precompute projected level 0 and level 1 features for the mask decoder
+ # to avoid running the projection again on every click
+ feature_maps = list(feature_maps)
+ feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
+ feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
+
+ # flatten NxCxHxW to HWxNxC
+ feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
+ feature_maps_position_embeddings = [
+ feature_map_position_embedding.flatten(2).permute(2, 0, 1)
+ for feature_map_position_embedding in feature_maps_position_embeddings
+ ]
+
+ return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions
+
+
+__all__ = ["EdgeTamModel", "EdgeTamVisionModel", "EdgeTamPreTrainedModel"]
diff --git a/src/transformers/models/edgetam/modular_edgetam.py b/src/transformers/models/edgetam/modular_edgetam.py
new file mode 100644
index 000000000000..e26d58d96b81
--- /dev/null
+++ b/src/transformers/models/edgetam/modular_edgetam.py
@@ -0,0 +1,261 @@
+# coding=utf-8
+# Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SAM 2 model."""
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from transformers.models.sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig
+from transformers.models.sam2.modeling_sam2 import (
+ Sam2Attention,
+ Sam2FeedForward,
+ Sam2LayerNorm,
+ Sam2Model,
+ Sam2PreTrainedModel,
+ Sam2TwoWayAttentionBlock,
+ Sam2VisionEncoderOutput,
+ Sam2VisionModel,
+)
+from transformers.utils.generic import TransformersKwargs, check_model_inputs
+
+from ...configuration_utils import PretrainedConfig
+from ...processing_utils import Unpack
+from ...utils import (
+ auto_docstring,
+)
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+# fix this in modular
+if True:
+ from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel
+
+
+class EdgeTamVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`EdgeTamVisionModel`]. It is used to instantiate an
+ EdgeTAM vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a configuration similar to that of the EdgeTAM
+ [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*):
+ Configuration for the vision backbone. This is used to instantiate the backbone using
+ `AutoModel.from_config`.
+ backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
+ The list of channel dimensions for the backbone.
+ backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`):
+ The spatial sizes of the feature maps from the backbone.
+ fpn_hidden_size (`int`, *optional*, defaults to 256):
+ The hidden dimension of the FPN.
+ fpn_kernel_size (`int`, *optional*, defaults to 1):
+ The kernel size for the convolutions in the neck.
+ fpn_stride (`int`, *optional*, defaults to 1):
+ The stride for the convolutions in the neck.
+ fpn_padding (`int`, *optional*, defaults to 0):
+ The padding for the convolutions in the neck.
+ fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`):
+ The levels for the top-down FPN connections.
+ num_feature_levels (`int`, *optional*, defaults to 3):
+ The number of feature levels from the FPN to use.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the neck.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon for the layer normalization.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
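+ Example (a minimal sketch; note that the default `backbone_config` is resolved from the `timm/repvit_m1.dist_in1k` hub entry, so `timm` must be installed):
+
+ ```python
+ >>> from transformers import EdgeTamVisionConfig, EdgeTamVisionModel
+
+ >>> # Initializing a configuration with the default (RepViT) backbone
+ >>> configuration = EdgeTamVisionConfig()
+
+ >>> # Initializing a vision model (with random weights) from the configuration
+ >>> model = EdgeTamVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```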
+ """
+
+ base_config_key = "vision_config"
+ model_type = "edgetam_vision_model"
+ sub_configs = {
+ "backbone_config": AutoConfig,
+ }
+
+ def __init__(
+ self,
+ backbone_config=None,
+ backbone_channel_list=None,
+ backbone_feature_sizes=None,
+ fpn_hidden_size=256,
+ fpn_kernel_size=1,
+ fpn_stride=1,
+ fpn_padding=0,
+ fpn_top_down_levels=None,
+ num_feature_levels=3,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ backbone_channel_list = [384, 192, 96, 48] if backbone_channel_list is None else backbone_channel_list
+ backbone_feature_sizes = (
+ [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes
+ )
+ fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels
+
+ if isinstance(backbone_config, dict):
+ backbone_config["model_type"] = backbone_config.get("model_type", "timm_wrapper")
+ backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config)
+ elif isinstance(backbone_config, AutoConfig):
+ backbone_config = backbone_config
+ elif backbone_config is None:
+ backbone_config = AutoConfig.from_pretrained(
+ "timm/repvit_m1.dist_in1k",
+ model_args={"in_chans": 3, "features_only": True, "out_indices": [0, 1, 2, 3]},
+ )
+
+ self.backbone_config = backbone_config
+
+ # Neck
+ self.backbone_channel_list = backbone_channel_list
+ self.backbone_feature_sizes = backbone_feature_sizes
+ self.fpn_hidden_size = fpn_hidden_size
+ self.fpn_kernel_size = fpn_kernel_size
+ self.fpn_stride = fpn_stride
+ self.fpn_padding = fpn_padding
+ self.fpn_top_down_levels = fpn_top_down_levels
+ self.num_feature_levels = num_feature_levels
+
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+
+
+class EdgeTamPromptEncoderConfig(Sam2PromptEncoderConfig):
+ pass
+
+
+class EdgeTamMaskDecoderConfig(Sam2MaskDecoderConfig):
+ pass
+
+
+class EdgeTamConfig(Sam2Config):
+ pass
+
+
+class EdgeTamLayerNorm(Sam2LayerNorm):
+ pass
+
+
+class EdgeTamVisionEncoderOutput(Sam2VisionEncoderOutput):
+ pass
+
+
+class EdgeTamAttention(Sam2Attention):
+ pass
+
+
+class EdgeTamTwoWayAttentionBlock(Sam2TwoWayAttentionBlock):
+ pass
+
+
+class EdgeTamFeedForward(Sam2FeedForward):
+ pass
+
+
+@auto_docstring
+class EdgeTamPreTrainedModel(Sam2PreTrainedModel):
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ if isinstance(module, EdgeTamModel):
+ if module.no_memory_embedding is not None:
+ module.no_memory_embedding.data.zero_()
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from EdgeTAM without any head or projection on top.
+ """
+)
+class EdgeTamVisionModel(Sam2VisionModel):
+ config_class = EdgeTamVisionConfig
+ main_input_name = "pixel_values"
+ _can_record_outputs = {"hidden_states": TimmWrapperModel, "attentions": TimmWrapperModel}
+
+ def get_input_embeddings(self):
+ raise NotImplementedError("Can't get input embeddings from timm wrapper model")
+
+ @check_model_inputs
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, EdgeTamVisionEncoderOutput]:
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Forward through backbone
+ backbone_output = self.backbone(pixel_values)
+ intermediate_hidden_states = backbone_output.last_hidden_state
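+ # the backbone returns (batch, channels, height, width) feature maps; permute to channels-last before the FPN neck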
+ intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states]
+
+ fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
+ # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
+ fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
+ fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
+
+ return EdgeTamVisionEncoderOutput(
+ last_hidden_state=intermediate_hidden_states[-1],
+ fpn_hidden_states=fpn_hidden_states,
+ fpn_position_encoding=fpn_position_encoding,
+ )
+
+
+class EdgeTamModel(Sam2Model):
+ _keys_to_ignore_on_load_unexpected = [
+ r"^memory_.*",
+ r"^mask_downsample.*",
+ r"spatial_perceiver.*",
+ r"^object_pointer_proj.*",
+ r"^temporal_positional_encoding_projection_layer.*",
+ "no_memory_positional_encoding",
+ "no_object_pointer",
+ "occlusion_spatial_embedding_parameter",
+ ]
+
+ def get_input_embeddings(self):
+ raise NotImplementedError("Can't get input embeddings from timm wrapper model")
+
+
+__all__ = [
+ "EdgeTamModel",
+ "EdgeTamVisionModel",
+ "EdgeTamPreTrainedModel",
+ "EdgeTamConfig",
+ "EdgeTamVisionConfig",
+ "EdgeTamPromptEncoderConfig",
+ "EdgeTamMaskDecoderConfig",
+]
diff --git a/src/transformers/models/edgetam_video/__init__.py b/src/transformers/models/edgetam_video/__init__.py
new file mode 100644
index 000000000000..669dd64ec304
--- /dev/null
+++ b/src/transformers/models/edgetam_video/__init__.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_edgetam_video import *
+ from .modeling_edgetam_video import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/edgetam_video/configuration_edgetam_video.py b/src/transformers/models/edgetam_video/configuration_edgetam_video.py
new file mode 100644
index 000000000000..954864397dcb
--- /dev/null
+++ b/src/transformers/models/edgetam_video/configuration_edgetam_video.py
@@ -0,0 +1,435 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/edgetam_video/modular_edgetam_video.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_edgetam_video.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+class EdgeTamVideoPromptEncoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`EdgeTamVideoPromptEncoder`]. The [`EdgeTamVideoPromptEncoder`]
+ module is used to encode the input 2D points and bounding boxes.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states.
+ image_size (`int`, *optional*, defaults to 1024):
+ The expected output resolution of the image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ mask_input_channels (`int`, *optional*, defaults to 16):
+ The number of channels to be fed to the `MaskDecoder` module.
+ num_point_embeddings (`int`, *optional*, defaults to 4):
+ The number of point embeddings to be used.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the encoder and pooler.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ scale (`float`, *optional*, defaults to 1):
+ The scale factor for the prompt encoder.
+ """
+
+ base_config_key = "prompt_encoder_config"
+
+ def __init__(
+ self,
+ hidden_size=256,
+ image_size=1024,
+ patch_size=16,
+ mask_input_channels=16,
+ num_point_embeddings=4,
+ hidden_act="gelu",
+ layer_norm_eps=1e-6,
+ scale=1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.mask_input_channels = mask_input_channels
+ self.num_point_embeddings = num_point_embeddings
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.scale = scale
+
+
+class EdgeTamVideoMaskDecoderConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`EdgeTamVideoMaskDecoder`]. It is used to instantiate an
+ EdgeTAM video mask decoder according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden states.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the EdgeTAM video mask decoder.
+ mlp_dim (`int`, *optional*, defaults to 2048):
+ The dimension of the MLP in the two-way transformer.
+ num_hidden_layers (`int`, *optional*, defaults to 2):
+ The number of hidden layers in the two-way transformer.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of attention heads in the two-way transformer.
+ attention_downsample_rate (`int`, *optional*, defaults to 2):
+ The downsample rate for the attention layers.
+ num_multimask_outputs (`int`, *optional*, defaults to 3):
+ The number of multimask outputs.
+ iou_head_depth (`int`, *optional*, defaults to 3):
+ The depth of the IoU head.
+ iou_head_hidden_dim (`int`, *optional*, defaults to 256):
+ The hidden dimension of the IoU head.
+ dynamic_multimask_via_stability (`bool`, *optional*, defaults to `True`):
+ Whether to use dynamic multimask via stability.
+ dynamic_multimask_stability_delta (`float`, *optional*, defaults to 0.05):
+ The stability delta for the dynamic multimask.
+ dynamic_multimask_stability_thresh (`float`, *optional*, defaults to 0.98):
+ The stability threshold for the dynamic multimask.
+
+ """
+
+ base_config_key = "mask_decoder_config"
+
+ def __init__(
+ self,
+ hidden_size=256,
+ hidden_act="gelu",
+ mlp_dim=2048,
+ num_hidden_layers=2,
+ num_attention_heads=8,
+ attention_downsample_rate=2,
+ num_multimask_outputs=3,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ dynamic_multimask_via_stability=True,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_multimask_outputs = num_multimask_outputs
+ self.hidden_act = hidden_act
+ self.iou_head_depth = iou_head_depth
+ self.iou_head_hidden_dim = iou_head_hidden_dim
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+ # TwoWayTransformer configuration
+ self.num_hidden_layers = num_hidden_layers
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.mlp_dim = mlp_dim
+ self.attention_downsample_rate = attention_downsample_rate
+
+
+class EdgeTamVideoConfig(PretrainedConfig):
+ r"""
+ [`EdgeTamVideoConfig`] is the configuration class to store the configuration of a [`EdgeTamVideoModel`]. It is used to instantiate an
+ EdgeTAM video model according to the specified arguments, defining the vision encoder, prompt encoder, mask decoder, memory
+ attention, and memory encoder configurations. Instantiating a configuration with the defaults will yield a configuration
+ similar to that of the EdgeTAM [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (Union[`dict`, `EdgeTamVisionConfig`], *optional*):
+ Dictionary of configuration options used to initialize the vision configuration.
+ prompt_encoder_config (Union[`dict`, `EdgeTamVideoPromptEncoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`EdgeTamVideoPromptEncoderConfig`].
+ mask_decoder_config (Union[`dict`, `EdgeTamVideoMaskDecoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`EdgeTamVideoMaskDecoderConfig`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ Standard deviation for parameter initialization.
+ num_maskmem (`int`, *optional*, defaults to 7):
+ The number of memory slots for the mask memory.
+ image_size (`int`, *optional*, defaults to 1024):
+ The size of the input images.
+ sigmoid_scale_for_mem_enc (`float`, *optional*, defaults to 20.0):
+ Scale factor for the sigmoid function in the memory encoder.
+ sigmoid_bias_for_mem_enc (`float`, *optional*, defaults to -10.0):
+ Bias for the sigmoid function in the memory encoder.
+ enable_occlusion_spatial_embedding (`bool`, *optional*, defaults to `True`):
+ Whether to enable spatial embedding for occlusions.
+ multimask_output_in_sam (`bool`, *optional*, defaults to `True`):
+ Whether to output multiple masks from the SAM head.
+ multimask_min_pt_num (`int`, *optional*, defaults to 0):
+ The minimum number of points to trigger multimask output.
+ multimask_max_pt_num (`int`, *optional*, defaults to 1):
+ The maximum number of points to trigger multimask output.
+ multimask_output_for_tracking (`bool`, *optional*, defaults to `True`):
+ Whether to use multimask output for tracking.
+ max_object_pointers_in_encoder (`int`, *optional*, defaults to 16):
+ The maximum number of object pointers in the encoder.
+ enable_temporal_pos_encoding_for_object_pointers (`bool`, *optional*, defaults to `True`):
+ Whether to enable temporal positional encoding for object pointers.
+ memory_attention_hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the memory attention hidden states.
+ memory_attention_num_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the memory attention module.
+ memory_attention_num_attention_heads (`int`, *optional*, defaults to 1):
+ Number of attention heads for each attention layer in the memory attention.
+ memory_attention_downsample_rate (`int`, *optional*, defaults to 1):
+ The downsample rate for the attention layers.
+ memory_attention_mlp_hidden_size (`int`, *optional*, defaults to 2048):
+ The dimension of the feedforward network in the memory attention module.
+ memory_attention_mlp_hidden_act (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function in the feedforward network in the memory attention module.
+ memory_attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout rate for the memory attention module.
+ memory_attention_rope_theta (`float`, *optional*, defaults to 10000):
+ The RoPE theta parameter.
+ memory_attention_rope_feat_sizes (`Tuple[int, int]`, *optional*, defaults to `[64, 64]`):
+ The feature sizes for the RoPE positional encoding.
+ memory_attention_rope_k_sizes (`List[int]`, *optional*, defaults to `[16, 16]`):
+ The key feature sizes for the RoPE positional encoding in memory attention.
+ memory_attention_rope_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout rate for the RoPE positional encoding.
+ perceiver_resampler_num_latents (`int`, *optional*, defaults to 256):
+ The number of 1D latent tokens in the perceiver resampler.
+ perceiver_resampler_num_latents_2d (`int`, *optional*, defaults to 256):
+ The number of 2D latent tokens in the perceiver resampler.
+ perceiver_resampler_hidden_size (`int`, *optional*, defaults to 64):
+ The hidden size of the perceiver resampler.
+ perceiver_resampler_mlp_intermediate_size (`int`, *optional*, defaults to 256):
+ The intermediate size of the feedforward network in the perceiver resampler.
+ perceiver_resampler_num_attention_heads (`int`, *optional*, defaults to 1):
+ The number of attention heads in the perceiver resampler.
+ perceiver_resampler_attention_head_dim (`int`, *optional*, defaults to 64):
+ The dimension of each attention head in the perceiver resampler.
+ perceiver_resampler_num_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the perceiver resampler.
+ perceiver_resampler_hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout rate for the hidden layers in the perceiver resampler.
+ perceiver_resampler_attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout rate for the attention layers in the perceiver resampler.
+ memory_encoder_hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the memory encoder hidden states.
+ memory_encoder_output_channels (`int`, *optional*, defaults to 64):
+ The number of output channels for the memory encoder.
+ mask_downsampler_embed_dim (`int`, *optional*, defaults to 256):
+ The dimension of the mask downsampler embedding.
+ memory_fuser_intermediate_dim (`int`, *optional*, defaults to 1024):
+ The intermediate dimension of the memory fuser feedforward network.
+ mask_downsampler_kernel_size (`int`, *optional*, defaults to 3):
+ The kernel size for the mask downsampler.
+ mask_downsampler_stride (`int`, *optional*, defaults to 2):
+ The stride for the mask downsampler.
+ mask_downsampler_padding (`int`, *optional*, defaults to 1):
+ The padding for the mask downsampler.
+ mask_downsampler_total_stride (`int`, *optional*, defaults to 16):
+ The total stride for the mask downsampler.
+ mask_downsampler_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the mask downsampler.
+ memory_fuser_num_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the memory fuser.
+ memory_fuser_embed_dim (`int`, *optional*, defaults to 256):
+ The dimension of the memory fuser embedding.
+ memory_fuser_kernel_size (`int`, *optional*, defaults to 7):
+ The kernel size for the memory fuser.
+ memory_fuser_padding (`int`, *optional*, defaults to 3):
+ The padding for the memory fuser.
+ memory_fuser_layer_scale_init_value (`float`, *optional*, defaults to 1e-06):
+ The initial value for the layer scale in the memory fuser.
+ memory_fuser_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the memory fuser.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... EdgeTamVisionConfig,
+ ... EdgeTamVideoPromptEncoderConfig,
+ ... EdgeTamVideoMaskDecoderConfig,
+ ... EdgeTamVideoModel,
+ ... EdgeTamVideoConfig,
+ ... )
+
+ >>> # Initializing a EdgeTamVideoConfig with `"facebook/edgetam.1_hiera_tiny"` style configuration
+ >>> configuration = EdgeTamVideoConfig()
+
+ >>> # Initializing a EdgeTamVideoModel (with random weights) from the `"facebook/edgetam.1_hiera_tiny"` style configuration
+ >>> model = EdgeTamVideoModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize a EdgeTamVideoConfig from a EdgeTamVisionConfig, EdgeTamVideoPromptEncoderConfig, and EdgeTamVideoMaskDecoderConfig
+
+ >>> # Initializing EdgeTAM vision encoder, prompt encoder, and mask decoder configurations
+ >>> vision_config = EdgeTamVisionConfig()
+ >>> prompt_encoder_config = EdgeTamVideoPromptEncoderConfig()
+ >>> mask_decoder_config = EdgeTamVideoMaskDecoderConfig()
+
+ >>> config = EdgeTamVideoConfig(vision_config, prompt_encoder_config, mask_decoder_config)
+ ```"""
+
+ model_type = "edgetam_video"
+ sub_configs = {
+ "vision_config": AutoConfig,
+ "prompt_encoder_config": EdgeTamVideoPromptEncoderConfig,
+ "mask_decoder_config": EdgeTamVideoMaskDecoderConfig,
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ prompt_encoder_config=None,
+ mask_decoder_config=None,
+ initializer_range=0.02,
+ num_maskmem=7,
+ image_size=1024,
+ sigmoid_scale_for_mem_enc=20.0,
+ sigmoid_bias_for_mem_enc=-10.0,
+ enable_occlusion_spatial_embedding=True,
+ multimask_output_in_sam=True,
+ multimask_min_pt_num=0,
+ multimask_max_pt_num=1,
+ multimask_output_for_tracking=True,
+ max_object_pointers_in_encoder=16,
+ enable_temporal_pos_encoding_for_object_pointers=True,
+ # memory attention
+ memory_attention_hidden_size=256,
+ memory_attention_num_layers=2,
+ memory_attention_num_attention_heads=1,
+ memory_attention_downsample_rate=1,
+ memory_attention_mlp_hidden_size=2048,
+ memory_attention_mlp_hidden_act="relu",
+ memory_attention_dropout=0.1,
+ memory_attention_rope_theta=10000,
+ memory_attention_rope_feat_sizes=None,
+ memory_attention_rope_k_sizes=None,
+ memory_attention_rope_dropout=0.1,
+ # spatial perceiver resampler
+ perceiver_resampler_num_latents=256,
+ perceiver_resampler_num_latents_2d=256,
+ perceiver_resampler_hidden_size=64,
+ perceiver_resampler_mlp_intermediate_size=256,
+ perceiver_resampler_num_attention_heads=1,
+ perceiver_resampler_attention_head_dim=64,
+ perceiver_resampler_num_layers=2,
+ perceiver_resampler_hidden_dropout=0.0,
+ perceiver_resampler_attention_dropout=0.0,
+ # memory encoder
+ memory_encoder_hidden_size=256,
+ memory_encoder_output_channels=64,
+ mask_downsampler_embed_dim=256,
+ memory_fuser_intermediate_dim=1024,
+ mask_downsampler_kernel_size=3,
+ mask_downsampler_stride=2,
+ mask_downsampler_padding=1,
+ mask_downsampler_total_stride=16,
+ mask_downsampler_hidden_act="gelu",
+ memory_fuser_num_layers=2,
+ memory_fuser_embed_dim=256,
+ memory_fuser_kernel_size=7,
+ memory_fuser_padding=3,
+ memory_fuser_layer_scale_init_value=1e-6,
+ memory_fuser_hidden_act="gelu",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ vision_config = vision_config if vision_config is not None else {}
+ prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+ mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+ memory_attention_rope_feat_sizes = (
+ [64, 64] if memory_attention_rope_feat_sizes is None else memory_attention_rope_feat_sizes
+ )
+ memory_attention_rope_k_sizes = (
+ [16, 16] if memory_attention_rope_k_sizes is None else memory_attention_rope_k_sizes
+ )
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "sam2_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ if isinstance(prompt_encoder_config, EdgeTamVideoPromptEncoderConfig):
+ prompt_encoder_config = prompt_encoder_config.to_dict()
+ if isinstance(mask_decoder_config, EdgeTamVideoMaskDecoderConfig):
+ mask_decoder_config = mask_decoder_config.to_dict()
+
+ self.vision_config = vision_config
+ self.prompt_encoder_config = EdgeTamVideoPromptEncoderConfig(**prompt_encoder_config)
+ self.mask_decoder_config = EdgeTamVideoMaskDecoderConfig(**mask_decoder_config)
+
+ self.initializer_range = initializer_range
+ self.num_maskmem = num_maskmem # default 1 input frame + 6 previous frames
+ self.image_size = image_size
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc # scale factor for mask sigmoid prob
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc # bias factor for mask sigmoid prob
+ self.enable_occlusion_spatial_embedding = enable_occlusion_spatial_embedding
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.max_object_pointers_in_encoder = max_object_pointers_in_encoder
+ self.enable_temporal_pos_encoding_for_object_pointers = enable_temporal_pos_encoding_for_object_pointers
+
+ # memory attention
+ self.memory_attention_hidden_size = memory_attention_hidden_size
+ self.memory_attention_num_layers = memory_attention_num_layers
+ self.memory_attention_num_attention_heads = memory_attention_num_attention_heads
+ self.memory_attention_downsample_rate = memory_attention_downsample_rate
+ self.memory_attention_mlp_hidden_size = memory_attention_mlp_hidden_size
+ self.memory_attention_mlp_hidden_act = memory_attention_mlp_hidden_act
+ self.memory_attention_dropout = memory_attention_dropout
+ self.memory_attention_rope_theta = memory_attention_rope_theta
+ self.memory_attention_rope_feat_sizes = memory_attention_rope_feat_sizes
+ self.memory_attention_rope_k_sizes = memory_attention_rope_k_sizes
+ self.memory_attention_rope_dropout = memory_attention_rope_dropout
+
+ # spatial perceiver resampler
+ self.perceiver_resampler_num_latents = perceiver_resampler_num_latents
+ self.perceiver_resampler_num_latents_2d = perceiver_resampler_num_latents_2d
+ self.perceiver_resampler_hidden_size = perceiver_resampler_hidden_size
+ self.perceiver_resampler_mlp_intermediate_size = perceiver_resampler_mlp_intermediate_size
+ self.perceiver_resampler_attention_head_dim = perceiver_resampler_attention_head_dim
+ self.perceiver_resampler_num_attention_heads = perceiver_resampler_num_attention_heads
+ self.perceiver_resampler_num_layers = perceiver_resampler_num_layers
+ self.perceiver_resampler_hidden_dropout = perceiver_resampler_hidden_dropout
+ self.perceiver_resampler_attention_dropout = perceiver_resampler_attention_dropout
+
+ # memory encoder
+ self.memory_encoder_hidden_size = memory_encoder_hidden_size
+ self.memory_encoder_output_channels = memory_encoder_output_channels
+ self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
+ self.mask_downsampler_kernel_size = mask_downsampler_kernel_size
+ self.mask_downsampler_stride = mask_downsampler_stride
+ self.mask_downsampler_padding = mask_downsampler_padding
+ self.mask_downsampler_total_stride = mask_downsampler_total_stride
+ self.mask_downsampler_hidden_act = mask_downsampler_hidden_act
+ self.memory_fuser_num_layers = memory_fuser_num_layers
+ self.memory_fuser_embed_dim = memory_fuser_embed_dim
+ self.memory_fuser_intermediate_dim = memory_fuser_intermediate_dim
+ self.memory_fuser_kernel_size = memory_fuser_kernel_size
+ self.memory_fuser_padding = memory_fuser_padding
+ self.memory_fuser_layer_scale_init_value = memory_fuser_layer_scale_init_value
+ self.memory_fuser_hidden_act = memory_fuser_hidden_act
+
+
+__all__ = ["EdgeTamVideoMaskDecoderConfig", "EdgeTamVideoPromptEncoderConfig", "EdgeTamVideoConfig"]
diff --git a/src/transformers/models/edgetam_video/convert_edgetam_video_to_hf.py b/src/transformers/models/edgetam_video/convert_edgetam_video_to_hf.py
new file mode 100644
index 000000000000..6290bef5e1c8
--- /dev/null
+++ b/src/transformers/models/edgetam_video/convert_edgetam_video_to_hf.py
@@ -0,0 +1,320 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Convert EdgeTAM checkpoints from the original repository.
+
+URL: https://github.com/facebookresearch/EdgeTAM.
+"""
+
+import argparse
+import re
+
+import numpy as np
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from transformers import (
+ EdgeTamVideoConfig,
+ EdgeTamVideoMaskDecoderConfig,
+ EdgeTamVideoModel,
+ EdgeTamVideoPromptEncoderConfig,
+ EdgeTamVisionConfig,
+ Sam2ImageProcessorFast,
+ Sam2VideoProcessor,
+ Sam2VideoVideoProcessor,
+ TimmWrapperConfig,
+)
+
+
+def get_config(model_name):
+ backbone_config = TimmWrapperConfig.from_pretrained(
+ "timm/repvit_m1.dist_in1k",
+ model_args={"in_chans": 3, "features_only": True, "out_indices": (0, 1, 2, 3)},
+ )
+ vision_config = EdgeTamVisionConfig(backbone_config=backbone_config)
+
+ prompt_encoder_config = EdgeTamVideoPromptEncoderConfig()
+ mask_decoder_config = EdgeTamVideoMaskDecoderConfig()
+ enable_temporal_pos_encoding_for_object_pointers = False
+ enable_occlusion_spatial_embedding = False
+
+ config = EdgeTamVideoConfig(
+ vision_config=vision_config,
+ prompt_encoder_config=prompt_encoder_config,
+ mask_decoder_config=mask_decoder_config,
+ enable_temporal_pos_encoding_for_object_pointers=enable_temporal_pos_encoding_for_object_pointers,
+ enable_occlusion_spatial_embedding=enable_occlusion_spatial_embedding,
+ )
+
+ return config
+
+
+KEYS_TO_MODIFY_MAPPING = {
+ "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
+ "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
+ "iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
+ "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
+ "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
+ "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
+ "mask_downscaling.0": "mask_embed.conv1",
+ "mask_downscaling.1": "mask_embed.layer_norm1",
+ "mask_downscaling.3": "mask_embed.conv2",
+ "mask_downscaling.4": "mask_embed.layer_norm2",
+ "mask_downscaling.6": "mask_embed.conv3",
+ "dwconv": "depthwise_conv",
+ "pwconv": "pointwise_conv",
+ "fuser": "memory_fuser",
+ "point_embeddings": "point_embed",
+ "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
+ "obj_ptr_tpos_proj": "temporal_positional_encoding_projection_layer",
+ "no_obj_embed_spatial": "occlusion_spatial_embedding_parameter",
+ "sam_prompt_encoder": "prompt_encoder",
+ "sam_mask_decoder": "mask_decoder",
+ "maskmem_tpos_enc": "memory_temporal_positional_encoding",
+ "gamma": "scale",
+ "image_encoder.neck": "vision_encoder.neck",
+ "image_encoder": "vision_encoder.backbone",
+ "neck.0": "neck.conv1",
+ "neck.1": "neck.layer_norm1",
+ "neck.2": "neck.conv2",
+ "neck.3": "neck.layer_norm2",
+ "pix_feat_proj": "feature_projection",
+ "patch_embed.proj": "patch_embed.projection",
+ "no_mem_embed": "no_memory_embedding",
+ "no_mem_pos_enc": "no_memory_positional_encoding",
+ "obj_ptr": "object_pointer",
+ ".norm": ".layer_norm",
+ "trunk.": "",
+ "out_proj": "o_proj",
+ "body.": "timm_model.",
+ "ff.0": "mlp.layer_norm",
+ "ff.1": "mlp.up_proj",
+ "ff.3": "mlp.down_proj",
+}
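+# Illustrative effect of the substring rewrites above (applied in replace_keys below), e.g.
+# "sam_prompt_encoder.point_embeddings.0.weight" -> "prompt_encoder.point_embed.0.weight".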
+
+
+def replace_keys(state_dict):
+ model_state_dict = {}
+ output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*"
+ output_mask_decoder_mlps_pattern = r"mask_decoder.transformer.layers.(\d+).mlp.layers.(\d+).*"
+ output_mask_decoder_score_head_pattern = r"mask_decoder.pred_obj_score_head.layers.(\d+).*"
+ output_vision_encoder_mlps_pattern = r"vision_encoder.backbone.blocks.(\d+).mlp.layers.(\d+).*"
+ output_vision_encoder_neck_pattern = r"vision_encoder.neck.convs.(\d+).conv"
+ output_memory_encoder_projection_pattern = r"memory_encoder.o_proj.*"
+ memory_attention_pattern = r"memory_attention.*"
+ output_object_pointer_proj_pattern = r"object_pointer_proj.layers.(\d+).*"
+ output_memory_encoder_mask_downsampler_pattern = r"memory_encoder.mask_downsampler.encoder.(\d+).*"
+ perceiver_resampler_patterns = {
+ r"spatial_perceiver.latents": r"spatial_perceiver.latents_1d",
+ r"spatial_perceiver.latents_1d_2d": r"spatial_perceiver.latents_2d",
+ r"spatial_perceiver.layers.(\d+).attn.layer_norm_x": r"spatial_perceiver.layers.\1.layer_norm_input",
+ r"spatial_perceiver.layers.(\d+).attn.layer_norm_latents": r"spatial_perceiver.layers.\1.layer_norm_latents",
+ r"spatial_perceiver.layers.(\d+).self_attn.layer_norm": r"spatial_perceiver.layers.\1.layer_norm_self",
+ r"spatial_perceiver.layers.(\d+).attn.to_q": r"spatial_perceiver.layers.\1.cross_attention.q_proj",
+ r"spatial_perceiver.layers.(\d+).attn.to_kv": r"spatial_perceiver.layers.\1.cross_attention.kv_proj_combined",
+ r"spatial_perceiver.layers.(\d+).attn.to_out": r"spatial_perceiver.layers.\1.cross_attention.o_proj",
+ r"spatial_perceiver.layers.(\d+).self_attn.to_q": r"spatial_perceiver.layers.\1.self_attention.q_proj",
+ r"spatial_perceiver.layers.(\d+).self_attn.to_kv": r"spatial_perceiver.layers.\1.self_attention.kv_proj_combined",
+ r"spatial_perceiver.layers.(\d+).self_attn.to_out": r"spatial_perceiver.layers.\1.self_attention.o_proj",
+ r"spatial_perceiver.layers.(\d+).attn": r"spatial_perceiver.layers.\1.cross_attention",
+ r"spatial_perceiver.layers.(\d+).self_attn": r"spatial_perceiver.layers.\1.self_attention",
+ }
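+ # Illustrative effect of the pattern rewrites above, e.g.
+ # "spatial_perceiver.layers.0.attn.to_q" -> "spatial_perceiver.layers.0.cross_attention.q_proj"
+ # "spatial_perceiver.layers.0.self_attn.to_kv" -> "spatial_perceiver.layers.0.self_attention.kv_proj_combined"
+ # (combined kv weights are then split into separate k_proj / v_proj tensors further below)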
+
+ for key, value in state_dict.items():
+ for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
+ if key_to_modify in key:
+ key = key.replace(key_to_modify, new_key)
+
+ for pattern, replacement in perceiver_resampler_patterns.items():
+ if re.match(pattern, key):
+ key = re.sub(pattern, replacement, key)
+
+ # vision_encoder.blocks.0.mlp.layers.1.weight -> vision_encoder.blocks.0.mlp.proj_out.weight
+ if re.match(output_vision_encoder_mlps_pattern, key):
+ layer_nb = int(re.match(output_vision_encoder_mlps_pattern, key).group(2))
+ if layer_nb == 0:
+ key = key.replace("layers.0", "proj_in")
+ elif layer_nb == 1:
+ key = key.replace("layers.1", "proj_out")
+
+ if re.match(memory_attention_pattern, key):
+ key = key.replace("linear1", "mlp.up_proj")
+ key = key.replace("linear2", "mlp.down_proj")
+
+ # mask_decoder.transformer.layers.0.mlp.layers.1.weight -> mask_decoder.transformer.layers.0.mlp.proj_out.weight
+ if re.match(output_mask_decoder_mlps_pattern, key):
+ layer_nb = int(re.match(output_mask_decoder_mlps_pattern, key).group(2))
+ if layer_nb == 0:
+ key = key.replace("mlp.layers.0", "mlp.proj_in")
+ elif layer_nb == 1:
+ key = key.replace("mlp.layers.1", "mlp.proj_out")
+
+ # mask_decoder.pred_obj_score_head.layers.0.weight -> mask_decoder.pred_obj_score_head.proj_in.weight
+ if re.match(output_mask_decoder_score_head_pattern, key):
+ layer_nb = int(re.match(output_mask_decoder_score_head_pattern, key).group(1))
+ if layer_nb == 0:
+ key = key.replace("layers.0", "proj_in")
+ elif layer_nb == 1:
+ key = key.replace("layers.1", "layers.0")
+ elif layer_nb == 2:
+ key = key.replace("layers.2", "proj_out")
+
+ if re.match(output_hypernetworks_mlps_pattern, key):
+ layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
+ if layer_nb == 0:
+ key = key.replace("layers.0", "proj_in")
+ elif layer_nb == 1:
+ key = key.replace("layers.1", "layers.0")
+ elif layer_nb == 2:
+ key = key.replace("layers.2", "proj_out")
+
+ # vision_encoder.neck.convs.1.conv.bias -> vision_encoder.neck.convs.1.bias
+ if re.match(output_vision_encoder_neck_pattern, key):
+ key = key.replace(".conv.", ".")
+
+ # memory_encoder.o_proj.weight -> memory_encoder.projection.weight
+ if re.match(output_memory_encoder_projection_pattern, key):
+ key = key.replace(".o_proj.", ".projection.")
+
+ if re.match(output_object_pointer_proj_pattern, key):
+ layer_nb = int(re.match(output_object_pointer_proj_pattern, key).group(1))
+ if layer_nb == 0:
+ key = key.replace("layers.0", "proj_in")
+ elif layer_nb == 1:
+ key = key.replace("layers.1", "layers.0")
+ elif layer_nb == 2:
+ key = key.replace("layers.2", "proj_out")
+
+ key = key.replace("layers.2", "proj_out")
+
+ if re.match(output_memory_encoder_mask_downsampler_pattern, key):
+ layer_nb = int(re.match(output_memory_encoder_mask_downsampler_pattern, key).group(1))
+ if layer_nb == 12:
+ key = key.replace(f"encoder.{layer_nb}", "final_conv")
+ elif layer_nb % 3 == 0:
+ key = key.replace(f"encoder.{layer_nb}", f"layers.{layer_nb // 3}.conv")
+ elif layer_nb % 3 == 1:
+ key = key.replace(f"encoder.{layer_nb}", f"layers.{layer_nb // 3}.layer_norm")
+ if "kv_proj_combined" in key:
+ # Split the weight tensor in half along dimension 0 (output dimension)
+ k_weight, v_weight = torch.chunk(value, 2, dim=0)
+ # Create the k_proj and v_proj keys
+ k_key = key.replace("kv_proj_combined", "k_proj")
+ v_key = key.replace("kv_proj_combined", "v_proj")
+ model_state_dict[k_key] = k_weight
+ model_state_dict[v_key] = v_weight
+ continue
+
+ model_state_dict[key] = value
+
+ model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[
+ "prompt_encoder.shared_embedding.positional_embedding"
+ ]
+ model_state_dict["prompt_encoder.point_embed.weight"] = torch.cat(
+ [model_state_dict.pop(f"prompt_encoder.point_embed.{i}.weight") for i in range(4)],
+ dim=0,
+ )
+
+ return model_state_dict
+
+
+def convert_edgetam_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub, run_sanity_check):
+ config = get_config(model_name)
+
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+ state_dict = replace_keys(state_dict)
+
+ image_processor = Sam2ImageProcessorFast()
+ video_processor = Sam2VideoVideoProcessor()
+ processor = Sam2VideoProcessor(image_processor=image_processor, video_processor=video_processor)
+ hf_model = EdgeTamVideoModel(config)
+ hf_model.eval()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=True)
+ hf_model = hf_model.to(device)
+ print("Missing keys:", missing_keys)
+ print("Unexpected keys:", unexpected_keys)
+
+ if run_sanity_check:
+ img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+
+ input_points = [[[[1000, 600]]]]
+ input_labels = [[[1]]]
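+ # input nesting (SAM-style prompts): (images, objects/prompts, points per prompt, xy) with matching labels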
+
+ inputs = processor(
+ images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt"
+ ).to(device)
+
+ with torch.no_grad():
+ output = hf_model._single_frame_forward(**inputs)
+ scores = output.iou_scores.squeeze()
+
+ assert torch.allclose(scores, torch.tensor([0.0356, 0.2141, 0.9707]).to(device), atol=1e-3)
+
+ if pytorch_dump_folder is not None:
+ processor.save_pretrained(pytorch_dump_folder)
+ hf_model.save_pretrained(pytorch_dump_folder)
+
+ if push_to_hub:
+ repo_id = f"yonigozlan/{pytorch_dump_folder.split('/')[-1]}"
+ processor.push_to_hub(repo_id)
+ hf_model.push_to_hub(repo_id)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ choices = ["EdgeTAM"]
+ parser.add_argument(
+ "--model_name",
+ default="EdgeTAM",
+ choices=choices,
+ type=str,
+ help="Name of the original model to convert",
+ )
+ parser.add_argument(
+ "--checkpoint_path",
+ type=str,
+ required=False,
+ help="Path to the original checkpoint",
+ )
+ parser.add_argument("--pytorch_dump_folder_path", default="", type=str, help="Path to the output PyTorch model.")
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether to push the model and processor to the hub after converting",
+ )
+ parser.add_argument(
+ "--run_sanity_check",
+ action="store_true",
+ help="Whether to run the sanity check after converting",
+ )
+
+ args = parser.parse_args()
+
+ hf_model_name = args.model_name.replace("_", "-")
+ checkpoint_path = (
+ hf_hub_download(f"facebook/{hf_model_name}", f"{args.model_name.lower()}.pt")
+ if args.checkpoint_path is None
+ else args.checkpoint_path
+ )
+
+ convert_edgetam_checkpoint(
+ args.model_name, checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub, args.run_sanity_check
+ )
diff --git a/src/transformers/models/edgetam_video/modeling_edgetam_video.py b/src/transformers/models/edgetam_video/modeling_edgetam_video.py
new file mode 100644
index 000000000000..3ba7ab4ebf2f
--- /dev/null
+++ b/src/transformers/models/edgetam_video/modeling_edgetam_video.py
@@ -0,0 +1,3062 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/edgetam_video/modular_edgetam_video.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_edgetam_video.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import OrderedDict
+from collections.abc import Iterator
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from tqdm import tqdm
+
+from transformers.utils.generic import OutputRecorder
+
+from ...activations import ACT2FN
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import ModelOutput, auto_docstring
+from ...utils.generic import TransformersKwargs
+from ..auto import AutoModel
+from .configuration_edgetam_video import (
+ EdgeTamVideoConfig,
+ EdgeTamVideoMaskDecoderConfig,
+ EdgeTamVideoPromptEncoderConfig,
+)
+
+
+class EdgeTamVideoLayerNorm(nn.LayerNorm):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
+ super().__init__(normalized_shape, eps=eps, **kwargs)
+ if data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
+ self.data_format = data_format
+
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
+ """
+ if self.data_format == "channels_first":
+ features = features.permute(0, 2, 3, 1)
+ features = super().forward(features)
+ features = features.permute(0, 3, 1, 2)
+ else:
+ features = super().forward(features)
+ return features
+
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class EdgeTamVideoMemoryFuserCXBlock(GradientCheckpointingLayer):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.depthwise_conv = nn.Conv2d(
+ config.memory_fuser_embed_dim,
+ config.memory_fuser_embed_dim,
+ kernel_size=config.memory_fuser_kernel_size,
+ padding=config.memory_fuser_padding,
+ groups=config.memory_fuser_embed_dim,
+ ) # depthwise conv
+ self.layer_norm = EdgeTamVideoLayerNorm(config.memory_fuser_embed_dim, eps=1e-6, data_format="channels_first")
+ self.activation = ACT2FN[config.memory_fuser_hidden_act]
+ self.pointwise_conv1 = nn.Linear(
+ config.memory_fuser_embed_dim, config.memory_fuser_intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.pointwise_conv2 = nn.Linear(config.memory_fuser_intermediate_dim, config.memory_fuser_embed_dim)
+ self.scale = nn.Parameter(
+ config.memory_fuser_layer_scale_init_value * torch.ones(config.memory_fuser_embed_dim),
+ requires_grad=True,
+ )
+
+ def forward(self, hidden_states):
+ input = hidden_states
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.scale * hidden_states
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ hidden_states = input + hidden_states
+ return hidden_states
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
+class EdgeTamVideoVisionEncoderOutput(ModelOutput):
+ r"""
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ fpn_hidden_states (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
+ fpn_position_encoding (`tuple(torch.FloatTensor)`):
+ Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
+ `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
+ model at the output of each stage.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ fpn_hidden_states: Optional[torch.FloatTensor] = None
+ fpn_position_encoding: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+class EdgeTamVideoVisionRotaryEmbedding(nn.Module):
+ """
+ Vision Rotary Position Embedding for EdgeTamVideo, following transformers library standards.
+ Supports 2D (axial) rotary embeddings for spatial dimensions.
+ """
+
+ def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
+ super().__init__()
+ dim = config.memory_attention_hidden_size // (
+ config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
+ )
+ # Ensure even dimension for proper axial splitting
+ if dim % 4 != 0:
+ raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+ end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
+ freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ # Generate 2D position indices for axial rotary embedding
+ flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+ x_positions = flattened_indices % end_x
+ y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
+ freqs_x = torch.outer(x_positions, freqs).float()
+ freqs_y = torch.outer(y_positions, freqs).float()
+ inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+ inv_freq = inv_freq.repeat_interleave(2, dim=-1)
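+ # inv_freq has shape (end_x * end_y, dim), with each frequency duplicated for the pairwise rotation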
+ # directly register the cos and sin embeddings as we have a fixed feature shape
+ self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
+ self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
+
+ @torch.no_grad()
+ def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
+ # As the feature map size is fixed, we can just return the pre-computed embeddings.
+ return self.rope_embeddings_cos, self.rope_embeddings_sin
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class EdgeTamVideoAttention(nn.Module):
+ """
+ EdgeTamVideo's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+ values.
+ """
+
+ def __init__(self, config, downsample_rate=None):
+ super().__init__()
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.internal_dim = config.hidden_size // downsample_rate
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = self.internal_dim // config.num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_similarity: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=attention_similarity,
+ dropout=0.0,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+def rotate_pairwise(x):
+ """
+ Pairwise rotation of the hidden dims of the input, different from Llama's half-tensor rotation.
+
+ This is an optimized version of the following more explicit implementation:
+ ```python
+ x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device)
+ x_rotated[..., ::2] = -x[..., 1::2]
+ x_rotated[..., 1::2] = x[..., ::2]
+ return x_rotated
+ ```
+ """
+ x = x.view(*x.shape[:-1], -1, 2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return x.flatten(start_dim=-2)
+
+
+def apply_rotary_pos_emb_2d_self_attn(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary position embedding to query and key tensors for self-attention.
+
+ Args:
+ q: Query tensor of shape (..., seq_len, head_dim)
+ k: Key tensor of shape (..., seq_len, head_dim)
+ cos: Cosine position embedding of shape (seq_len, head_dim)
+ sin: Sine position embedding of shape (seq_len, head_dim)
+
+ Returns:
+ Rotated (q, k) tensors
+ """
+ # Apply RoPE to queries
+ q_embed = q.float() # force upscale to float32 as in the original implementation
+ q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)
+
+ # Apply RoPE to keys (same embeddings as queries for self-attention)
+ k_embed = k.float() # force upscale to float32 as in the original implementation
+ k_embed = (k_embed * cos) + (rotate_pairwise(k_embed) * sin)
+
+ return q_embed.type_as(q), k_embed.type_as(k)
+
+
+class EdgeTamVideoRoPESelfAttention(nn.Module):
+ """Self-attention with rotary position encoding."""
+
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.memory_attention_hidden_size
+ self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate
+ self.num_attention_heads = config.memory_attention_num_attention_heads
+ self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+ self.dropout_p = config.memory_attention_rope_dropout
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tensor:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ # Apply rotary position encoding for self-attention
+ query, key = apply_rotary_pos_emb_2d_self_attn(query, key, cos=cos, sin=sin)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.dropout_p,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb_2d_cross_attn(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cos_k: torch.Tensor,
+ sin_k: torch.Tensor,
+ num_k_exclude_rope: int = 0,
+ repeat_freqs_k: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary position embedding to query and key tensors for cross-attention.
+
+ Args:
+ q: Query tensor of shape (..., seq_len, head_dim)
+ k: Key tensor of shape (..., seq_len, head_dim)
+ cos: Cosine position embedding of shape (seq_len, head_dim)
+ sin: Sine position embedding of shape (seq_len, head_dim)
+ cos_k: Cosine position embedding for keys of shape (seq_len, head_dim)
+ sin_k: Sine position embedding for keys of shape (seq_len, head_dim)
+ num_k_exclude_rope: Number of tokens at end of k to exclude from RoPE (e.g., object pointer tokens)
+ repeat_freqs_k: Frequency repetition for keys in cross-attention (e.g., for spatial memory tokens)
+
+ Returns:
+ Rotated (q, k) tensors
+ """
+ # Apply RoPE to queries (always straightforward)
+ q_embed = q.float()
+ q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)
+
+ # Split keys: RoPE tokens and excluded tokens (e.g., object pointers)
+ num_total_k_tokens = k.shape[-2]
+ k_for_rope = k[..., : num_total_k_tokens - num_k_exclude_rope, :]
+ k_excluded = k[..., num_total_k_tokens - num_k_exclude_rope :, :]
+
+ # Early return if no keys need RoPE
+ if k_for_rope.shape[-2] == 0:
+ return q_embed.type_as(q), k_excluded
+
+ batch_size, num_heads, k_seq_len, channels_per_head = k_for_rope.shape
+
+ # Handle temporal/spatial token structure for memory
+ # Keys have temporal + spatial structure, only spatial tokens get RoPE
+ tokens_per_group = k_seq_len // repeat_freqs_k
+ spatial_tokens = cos_k.shape[-2]
+ temporal_tokens = tokens_per_group - spatial_tokens
+
+ # Reshape and separate temporal/spatial tokens
+ k_grouped = k_for_rope.view(batch_size, num_heads, repeat_freqs_k, tokens_per_group, channels_per_head)
+ k_temporal = k_grouped[..., :temporal_tokens, :].reshape(batch_size, num_heads, -1, channels_per_head)
+ k_spatial = k_grouped[..., temporal_tokens:, :].reshape(batch_size, num_heads, -1, channels_per_head)
+
+ # Only apply RoPE to spatial tokens
+ k_rope_input = k_spatial
+
+ # Prepare position embeddings for repeated groups
+ if repeat_freqs_k > 1:
+ cos_k = cos_k.repeat(1, 1, repeat_freqs_k, 1)
+ sin_k = sin_k.repeat(1, 1, repeat_freqs_k, 1)
+
+ # Apply RoPE to spatial tokens
+ k_spatial_embed = k_rope_input.float()
+ k_spatial_embed = (k_spatial_embed * cos_k) + (rotate_pairwise(k_spatial_embed) * sin_k)
+
+ # Reconstruct: temporal + spatial tokens back to original structure
+ k_spatial_reshaped = k_spatial_embed.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head)
+ k_temporal_reshaped = k_temporal.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head)
+ k_final = torch.cat([k_temporal_reshaped, k_spatial_reshaped], dim=3)
+ k_final = k_final.view(batch_size, num_heads, k_seq_len, channels_per_head)
+
+ # Combine RoPE-processed keys with excluded tokens
+ k_embed = torch.cat([k_final.type_as(k), k_excluded], dim=-2)
+ return q_embed.type_as(q), k_embed
+
+
+class EdgeTamVideoRoPECrossAttention(nn.Module):
+ """Cross-attention with rotary position encoding."""
+
+ def __init__(self, config: EdgeTamVideoConfig, kv_in_dim: int):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.memory_attention_hidden_size
+ self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate
+ self.num_attention_heads = config.memory_attention_num_attention_heads
+ self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.kv_in_dim = kv_in_dim
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+ self.dropout_p = config.memory_attention_rope_dropout
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ position_embeddings_k: tuple[torch.Tensor, torch.Tensor],
+ num_k_exclude_rope: int = 0,
+ rope_k_repeat: int = 0,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tensor:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ cos_k, sin_k = position_embeddings_k
+ # Apply rotary position encoding for cross-attention
+ query, key = apply_rotary_pos_emb_2d_cross_attn(
+ query,
+ key,
+ cos=cos,
+ sin=sin,
+ cos_k=cos_k,
+ sin_k=sin_k,
+ repeat_freqs_k=rope_k_repeat,
+ num_k_exclude_rope=num_k_exclude_rope,
+ )
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.dropout_p,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class EdgeTamVideoTwoWayAttentionBlock(nn.Module):
+ def __init__(self, config: EdgeTamVideoMaskDecoderConfig, skip_first_layer_pe: bool = False):
+ """
+ A transformer block with four layers:
+ (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
+ sparse inputs (4) cross attention of dense inputs -> sparse inputs
+
+ Arguments:
+ config (`EdgeTamVideoMaskDecoderConfig`):
+ The configuration file used to instantiate the block
+ attention_downsample_rate (*optional*, int, defaults to 2):
+ The downsample ratio of the block used to reduce the inner dim of the attention.
+ skip_first_layer_pe (*optional*, bool, defaults to `False`):
+ Whether or not to skip the addition of the query_point_embedding on the first layer.
+ """
+ super().__init__()
+ self.self_attn = EdgeTamVideoAttention(config, downsample_rate=1)
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size)
+
+ self.cross_attn_token_to_image = EdgeTamVideoAttention(config)
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size)
+
+ self.mlp = EdgeTamVideoFeedForward(
+ config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
+ )
+ self.layer_norm3 = nn.LayerNorm(config.hidden_size)
+
+ self.layer_norm4 = nn.LayerNorm(config.hidden_size)
+ self.cross_attn_image_to_token = EdgeTamVideoAttention(config)
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ query_point_embedding: Tensor,
+ key_point_embedding: Tensor,
+ attention_similarity: Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ):
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries, _ = self.self_attn(query=queries, key=queries, value=queries)
+ else:
+ query = queries + query_point_embedding
+ attn_out, _ = self.self_attn(query=query, key=query, value=queries)
+ queries = queries + attn_out
+ queries = self.layer_norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_token_to_image(
+ query=query, key=key, value=keys, attention_similarity=attention_similarity
+ )
+ queries = queries + attn_out
+
+ queries = self.layer_norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.layer_norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ query = queries + query_point_embedding
+ key = keys + key_point_embedding
+
+ attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
+ keys = keys + attn_out
+
+ keys = self.layer_norm4(keys)
+ return queries, keys, attn_out
+
+
+# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
+class EdgeTamVideoPositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
+ need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
+ ):
+ super().__init__()
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = 2 * math.pi if scale is None else scale
+
+ @compile_compatible_method_lru_cache(maxsize=2)
+ def forward(
+ self,
+ shape: torch.Size,
+ device: Union[torch.device, str],
+ dtype: torch.dtype,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ if mask is None:
+ mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
+ not_mask = (~mask).to(dtype)
+ y_embed = not_mask.cumsum(1)
+ x_embed = not_mask.cumsum(2)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class EdgeTamVideoMemoryFuser(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [EdgeTamVideoMemoryFuserCXBlock(config) for _ in range(config.memory_fuser_num_layers)]
+ )
+
+ def forward(self, hidden_states):
+ # normally hidden_states: (N, C, H, W)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class EdgeTamVideoMaskDownSamplerLayer(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig, in_channels: int, out_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=config.mask_downsampler_kernel_size,
+ stride=config.mask_downsampler_stride,
+ padding=config.mask_downsampler_padding,
+ )
+ self.layer_norm = EdgeTamVideoLayerNorm(out_channels, eps=1e-6, data_format="channels_first")
+ self.activation = ACT2FN[config.mask_downsampler_hidden_act]
+
+ def forward(self, x):
+ return self.activation(self.layer_norm(self.conv(x)))
+
+
+class EdgeTamVideoMaskDownSampler(nn.Module):
+ """
+ Progressively downsample a mask by total_stride, each time by stride.
+ Note that LayerNorm is applied per *token*, like in ViT.
+
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+ In the end, we linearly project to embed_dim channels.
+ """
+
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+
+ num_layers = int(math.log2(config.mask_downsampler_total_stride) // math.log2(config.mask_downsampler_stride))
+
+ self.layers = nn.ModuleList()
+ self.activation = ACT2FN[config.mask_downsampler_hidden_act]
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (config.mask_downsampler_stride**2)
+ self.layers.append(EdgeTamVideoMaskDownSamplerLayer(config, mask_in_chans, mask_out_chans))
+ mask_in_chans = mask_out_chans
+
+ self.final_conv = nn.Conv2d(mask_out_chans, config.mask_downsampler_embed_dim, kernel_size=1)
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = layer(x)
+ x = self.final_conv(x)
+ return x
+
+
+class EdgeTamVideoMemoryEncoder(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+
+ hidden_size = config.memory_encoder_hidden_size
+ output_channels = config.memory_encoder_output_channels
+ self.mask_downsampler = EdgeTamVideoMaskDownSampler(config)
+ self.feature_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+ self.memory_fuser = EdgeTamVideoMemoryFuser(config)
+ self.position_encoding = EdgeTamVideoPositionEmbeddingSine(num_pos_feats=output_channels // 2, normalize=True)
+ self.projection = nn.Conv2d(hidden_size, output_channels, kernel_size=1)
+
+ def forward(
+ self,
+ vision_features: torch.Tensor,
+ masks: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Process masks
+ masks = self.mask_downsampler(masks)
+
+ # Fuse pixel features and downsampled masks
+ vision_features = self.feature_projection(vision_features)
+ vision_features = vision_features + masks
+ vision_features = self.memory_fuser(vision_features)
+ vision_features = self.projection(vision_features)
+
+ vision_pos_enc = self.position_encoding(vision_features.shape, vision_features.device, vision_features.dtype)
+
+ return vision_features, vision_pos_enc
+
+
+class EdgeTamVideoFeedForward(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ activation: str = "relu",
+ sigmoid_output: bool = False,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.activation = ACT2FN[activation]
+ self.proj_in = nn.Linear(input_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
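+ # num_layers linear layers in total: proj_in, (num_layers - 2) hidden layers, then proj_out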
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ for layer in self.layers:
+ hidden_states = self.activation(layer(hidden_states))
+
+ hidden_states = self.proj_out(hidden_states)
+ if self.sigmoid_output:
+ hidden_states = F.sigmoid(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class EdgeTamVideoPreTrainedModel(PreTrainedModel):
+ config_class = EdgeTamVideoConfig
+ base_model_prefix = "edgetam_video"
+ main_input_name = "pixel_values"
+ _supports_sdpa = True
+ _supports_flash_attn_2 = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, (nn.LayerNorm, EdgeTamVideoLayerNorm)):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+ elif isinstance(module, EdgeTamVideoModel):
+ if module.no_memory_positional_encoding is not None:
+ module.no_memory_positional_encoding.data.zero_()
+ if module.memory_temporal_positional_encoding is not None:
+ module.memory_temporal_positional_encoding.data.zero_()
+ if module.no_object_pointer is not None:
+ module.no_object_pointer.data.zero_()
+ if module.occlusion_spatial_embedding_parameter is not None:
+ module.occlusion_spatial_embedding_parameter.data.zero_()
+ if isinstance(module, EdgeTamVideoMemoryFuserCXBlock):
+ if module.scale is not None:
+ module.scale.data.zero_()
+
+
+class EdgeTamVideoInferenceCache:
+ """Cache for vision features and model constants."""
+
+ def __init__(
+ self,
+ inference_device: Union[torch.device, str] = "cpu",
+ inference_state_device: Union[torch.device, str] = "cpu",
+ max_vision_features_cache_size: int = 1,
+ ):
+ self.inference_device = inference_device
+ self.inference_state_device = inference_state_device
+ self.max_vision_features_cache_size = max_vision_features_cache_size
+
+ self._vision_features = {}
+
+ def cache_vision_features(self, frame_idx: int, features: dict):
+ """Cache vision features with automatic device management."""
+ cached = {}
+ if len(self._vision_features) >= self.max_vision_features_cache_size:
+ # remove the oldest frame
+ self._vision_features.pop(min(self._vision_features.keys()))
+
+ for key, value in features.items():
+ if isinstance(value, torch.Tensor):
+ cached[key] = value.to(self.inference_state_device, non_blocking=True)
+ elif isinstance(value, (list, tuple)) and value and isinstance(value[0], torch.Tensor):
+ cached[key] = [v.to(self.inference_state_device, non_blocking=True) for v in value]
+ else:
+ cached[key] = value
+ self._vision_features[frame_idx] = cached
+
+ def get_vision_features(self, frame_idx: int) -> Optional[dict]:
+ """Get cached vision features, automatically moved to inference device."""
+ if frame_idx not in self._vision_features:
+ return None
+
+ cached = self._vision_features[frame_idx]
+ moved = {}
+ for key, value in cached.items():
+ if isinstance(value, torch.Tensor):
+ moved[key] = value.to(self.inference_device, non_blocking=True)
+ elif isinstance(value, (list, tuple)) and value and isinstance(value[0], torch.Tensor):
+ moved[key] = [v.to(self.inference_device, non_blocking=True) for v in value]
+ else:
+ moved[key] = value
+ return moved
+
+ def clear_all(self):
+ """Clear all cached data."""
+ self._vision_features.clear()
+
+
+class EdgeTamVideoInferenceSession:
+ r"""
+ Manages video inference session parameters, state and cache.
+
+ Args:
+ video (`torch.FloatTensor`, *optional*):
+ The video to process. No need to provide when streaming.
+ video_height (`int`, *optional*):
+ The height of the video.
+ video_width (`int`, *optional*):
+ The width of the video.
+ inference_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to use for inference.
+ inference_state_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to store the inference state on.
+ video_storage_device (`torch.device`, *optional*, defaults to `"cpu"`):
+ The device to store the video on.
+ dtype (`torch.dtype`, *optional*, defaults to `"float32"`):
+ The dtype to use for the video.
+ max_vision_features_cache_size (`int`, *optional*, defaults to 1):
+ The maximum number of vision features to cache.
+ """
+
+ def __init__(
+ self,
+ video: Optional[torch.FloatTensor] = None,
+ video_height: Optional[int] = None,
+ video_width: Optional[int] = None,
+ inference_device: Union[torch.device, str] = "cpu",
+ inference_state_device: Union[torch.device, str] = "cpu",
+ video_storage_device: Union[torch.device, str] = "cpu",
+ dtype: Union[torch.dtype, str] = "float32",
+ max_vision_features_cache_size: int = 1,
+ ):
+ # store as a dictionary to avoid double memory allocation with torch.cat when adding new frames
+ self.processed_frames = (
+ dict(enumerate(video.to(video_storage_device, dtype=dtype))) if video is not None else None
+ )
+ self.video_height = video_height
+ self.video_width = video_width
+
+ self.inference_device = inference_device
+ self.inference_state_device = inference_state_device
+ self.video_storage_device = video_storage_device
+ self.dtype = dtype
+ self.max_vision_features_cache_size = max_vision_features_cache_size
+
+ # Cache for computed features
+ self.cache = EdgeTamVideoInferenceCache(
+ inference_device=self.inference_device,
+ inference_state_device=self.inference_state_device,
+ max_vision_features_cache_size=self.max_vision_features_cache_size,
+ )
+
+ # Persistent object tracking state
+ self._obj_id_to_idx = OrderedDict()
+ self._obj_idx_to_id = OrderedDict()
+ self.obj_ids = []
+
+ # Persistent user inputs
+ self.point_inputs_per_obj = {}
+ self.mask_inputs_per_obj = {}
+
+ # Persistent model outputs/history
+ self.output_dict_per_obj = {}
+ self.frames_tracked_per_obj = {}
+
+ # Session state flags
+ self.obj_with_new_inputs = []
+
+ @property
+ def num_frames(self) -> Optional[int]:
+ return len(self.processed_frames) if self.processed_frames is not None else None
+
+ # Object management
+ def obj_id_to_idx(self, obj_id: int) -> int:
+ """Map object ID to index, creating new entry if needed."""
+ obj_idx = self._obj_id_to_idx.get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+ obj_idx = len(self._obj_id_to_idx)
+ self._obj_id_to_idx[obj_id] = obj_idx
+ self._obj_idx_to_id[obj_idx] = obj_id
+ self.obj_ids = list(self._obj_id_to_idx)
+
+ self.point_inputs_per_obj[obj_idx] = {}
+ self.mask_inputs_per_obj[obj_idx] = {}
+ self.output_dict_per_obj[obj_idx] = {
+ "cond_frame_outputs": {},
+ "non_cond_frame_outputs": {},
+ }
+ self.frames_tracked_per_obj[obj_idx] = {}
+
+ return obj_idx
+
+ # Video Inference specific functions
+ def obj_idx_to_id(self, obj_idx: int) -> int:
+ """Map model-side object index to client-side object id."""
+ return self._obj_idx_to_id[obj_idx]
+
+ def get_obj_num(self) -> int:
+ """Get the total number of unique object ids received so far in this session."""
+ return len(self._obj_idx_to_id)
+
+ # Input management with device handling
+ def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict):
+ """Add point inputs with automatic device placement."""
+ device_inputs = {}
+ for key, value in inputs.items():
+ if isinstance(value, torch.Tensor):
+ device_inputs[key] = value.to(self.inference_device, non_blocking=True)
+ else:
+ device_inputs[key] = value
+ self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs
+
+ def remove_point_inputs(self, obj_idx: int, frame_idx: int):
+ """Remove point inputs."""
+ self.point_inputs_per_obj[obj_idx].pop(frame_idx, None)
+
+ def add_mask_inputs(self, obj_idx: int, frame_idx: int, inputs: torch.Tensor):
+ """Add mask inputs with automatic device placement."""
+ self.mask_inputs_per_obj[obj_idx][frame_idx] = inputs.to(
+ self.inference_device, dtype=self.dtype, non_blocking=True
+ )
+
+ def remove_mask_inputs(self, obj_idx: int, frame_idx: int):
+ """Remove mask inputs."""
+ self.mask_inputs_per_obj[obj_idx].pop(frame_idx, None)
+
+ # Output management with smart device placement
+ def store_output(
+ self,
+ obj_idx: int,
+ frame_idx: int,
+ output_key: Optional[str] = None,
+ output_value: Optional[Union[torch.Tensor, dict]] = None,
+ is_conditioning_frame: bool = True,
+ ):
+ """
+ Store output with smart device management.
+ If output_key is None, the output is stored as a dictionary.
+
+ Args:
+ obj_idx (int): The index of the object.
+ frame_idx (int): The index of the frame.
+ output_key (Optional[str]): The key of the output. If None, the output is stored as a dictionary.
+ output_value (Optional[Union[torch.Tensor, dict]]): The value of the output.
+ is_conditioning_frame (bool): Whether the output is for a conditioning frame.
+ """
+ storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs"
+
+ if output_key is None and isinstance(output_value, dict):
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx] = {}
+ for key, value in output_value.items():
+ self.store_output(obj_idx, frame_idx, key, value, is_conditioning_frame)
+ return
+
+ # Device placement: small tensors stay on inference device, large ones go to inference state device
+ if output_key in ["object_pointer", "object_score_logits"]: # Small tensors
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value
+ elif isinstance(output_value, torch.Tensor): # Large tensors like masks, features
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value.to(
+ self.inference_state_device, non_blocking=True
+ )
+ else:
+ self.output_dict_per_obj[obj_idx][storage_key][frame_idx][output_key] = output_value
+
+ def get_output(
+ self,
+ obj_idx: int,
+ frame_idx: int,
+ output_key: str,
+ is_conditioning_frame: bool = True,
+ ):
+ """
+ Get output with smart device management.
+
+ Args:
+ obj_idx (int): The index of the object.
+ frame_idx (int): The index of the frame.
+ output_key (str): The key of the output.
+ is_conditioning_frame (bool): Whether the output is for a conditioning frame.
+ """
+ storage_key = "cond_frame_outputs" if is_conditioning_frame else "non_cond_frame_outputs"
+ out = self.output_dict_per_obj[obj_idx][storage_key].get(frame_idx, None)
+ # move to inference device if needed
+ if out is None:
+ return None
+ value = out[output_key]
+ if isinstance(value, torch.Tensor):
+ value = value.to(self.inference_device, non_blocking=True)
+ return value
+
+ # Video frame management
+ def add_new_frame(self, pixel_values: torch.Tensor, frame_idx: Optional[int] = None) -> int:
+ """Add new frame with automatic device placement."""
+ pixel_values = pixel_values.to(self.video_storage_device, dtype=self.dtype, non_blocking=True)
+ if pixel_values.dim() == 4:
+ pixel_values = pixel_values.squeeze(0)
+
+ if frame_idx is None:
+ frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0
+
+ if self.processed_frames is None:
+ self.processed_frames = {frame_idx: pixel_values}
+ else:
+ self.processed_frames[frame_idx] = pixel_values
+
+ return frame_idx
+
+ def get_frame(self, frame_idx: int) -> torch.Tensor:
+ """Get frame from video."""
+ return self.processed_frames[frame_idx].to(self.inference_device, non_blocking=True)
+
+ def reset_tracking_data(self):
+ """Reset tracking data but keep cache."""
+ self._obj_id_to_idx.clear()
+ self._obj_idx_to_id.clear()
+ self.obj_ids.clear()
+ self.point_inputs_per_obj.clear()
+ self.mask_inputs_per_obj.clear()
+ self.output_dict_per_obj.clear()
+ self.frames_tracked_per_obj.clear()
+ self.obj_with_new_inputs = []
+ # Note: cache and video data are preserved
+
+ def reset_inference_session(self):
+ """Reset tracking data and cache."""
+ self._obj_id_to_idx.clear()
+ self._obj_idx_to_id.clear()
+ self.obj_ids.clear()
+ self.point_inputs_per_obj.clear()
+ self.mask_inputs_per_obj.clear()
+ self.output_dict_per_obj.clear()
+ self.frames_tracked_per_obj.clear()
+ self.obj_with_new_inputs = []
+ self.cache.clear_all()
+
+
+class EdgeTamVideoMemoryAttentionMLP(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.memory_attention_hidden_size
+ self.intermediate_size = config.memory_attention_mlp_hidden_size
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
+ self.dropout = nn.Dropout(config.memory_attention_dropout)
+ self.act_fn = ACT2FN[config.memory_attention_mlp_hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.dropout(self.act_fn(self.up_proj(x))))
+
+
+class EdgeTamVideoMemoryAttentionLayer(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ hidden_size = config.memory_attention_hidden_size
+ self.self_attn = EdgeTamVideoRoPESelfAttention(config)
+ self.cross_attn_image = EdgeTamVideoRoPECrossAttention(config, kv_in_dim=64)
+
+ # MLP module
+ self.mlp = EdgeTamVideoMemoryAttentionMLP(config)
+
+ self.layer_norm1 = nn.LayerNorm(hidden_size)
+ self.layer_norm2 = nn.LayerNorm(hidden_size)
+ self.layer_norm3 = nn.LayerNorm(hidden_size)
+ self.dropout1 = nn.Dropout(config.memory_attention_dropout)
+ self.dropout2 = nn.Dropout(config.memory_attention_dropout)
+ self.dropout3 = nn.Dropout(config.memory_attention_dropout)
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ key_point_embedding: Tensor,
+ rope_position_embeddings: tuple[Tensor, Tensor],
+ rope_position_embeddings_k: Optional[tuple[Tensor, Tensor]] = None,
+ num_k_exclude_rope: int = 0,
+ rope_k_repeat: int = 0,
+ ) -> torch.Tensor:
+ # Self-Attention
+ query = self.layer_norm1(queries)
+ query, _ = self.self_attn(query=query, key=query, value=query, position_embeddings=rope_position_embeddings)
+ queries = queries + self.dropout1(query)
+
+ # Cross-Attention
+ query = self.layer_norm2(queries)
+ query, _ = self.cross_attn_image(
+ query=query,
+ key=keys + key_point_embedding,
+ value=keys,
+ position_embeddings=rope_position_embeddings,
+ position_embeddings_k=rope_position_embeddings_k,
+ num_k_exclude_rope=num_k_exclude_rope,
+ rope_k_repeat=rope_k_repeat,
+ )
+ queries = queries + self.dropout2(query)
+ # MLP
+ query = self.layer_norm3(queries)
+ query = self.mlp(query)
+ queries = queries + self.dropout3(query)
+ return queries
+
+
+class EdgeTamVideoMemoryAttention(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [EdgeTamVideoMemoryAttentionLayer(config) for _ in range(config.memory_attention_num_layers)]
+ )
+ self.layer_norm = nn.LayerNorm(config.memory_attention_hidden_size)
+ self.rotary_emb = EdgeTamVideoVisionRotaryEmbedding(config=config)
+ self.rotary_emb_k = EdgeTamVideoVisionRotaryEmbedding(
+ config, end_x=config.memory_attention_rope_k_sizes[0], end_y=config.memory_attention_rope_k_sizes[1]
+ )
+
+ def forward(
+ self,
+ current_vision_features: torch.Tensor,
+ memory: torch.Tensor,
+ current_vision_position_embeddings: Optional[Tensor] = None,
+ memory_posision_embeddings: Optional[Tensor] = None,
+ num_object_pointer_tokens: int = 0,
+ num_spatial_memory_tokens: int = -1,
+ ):
+ """
+ Args:
+ current_vision_features (`torch.FloatTensor`):
+ The current vision features used for self-attention.
+ memory (`torch.FloatTensor`):
+ The memory features used for cross-attention.
+ current_vision_position_embeddings (`torch.FloatTensor`, *optional*):
+ The position embeddings for the current vision features.
+ memory_posision_embeddings (`torch.FloatTensor`, *optional*):
+ The position embeddings for the memory features.
+ num_object_pointer_tokens (`int`, *optional*, defaults to 0):
+ The number of object pointer tokens.
+ num_spatial_memory_tokens (`int`, *optional*, defaults to -1):
+ The number of spatial memory tokens, passed to the attention layers as the repeat factor for the
+ rotary position embeddings of the memory keys.
+ """
+ output = current_vision_features
+ if current_vision_position_embeddings is not None:
+ output = output + 0.1 * current_vision_position_embeddings
+
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ memory = memory.transpose(0, 1).unsqueeze(1)
+ memory_posision_embeddings = memory_posision_embeddings.transpose(0, 1).unsqueeze(1)
+ rope_position_embeddings = self.rotary_emb()
+ rope_position_embeddings_k = self.rotary_emb_k()
+ for layer in self.layers:
+ output = layer(
+ queries=output.unsqueeze(1) if output.ndim == 3 else output,
+ keys=memory,
+ key_point_embedding=memory_posision_embeddings,
+ rope_position_embeddings=rope_position_embeddings,
+ rope_position_embeddings_k=rope_position_embeddings_k,
+ num_k_exclude_rope=num_object_pointer_tokens,
+ rope_k_repeat=num_spatial_memory_tokens,
+ )
+
+ normed_output = self.layer_norm(output)
+
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+
+ return normed_output
+
+
+class EdgeTamVideoPerceiverMLP(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.hidden_size = config.perceiver_resampler_hidden_size
+ self.intermediate_size = config.perceiver_resampler_mlp_intermediate_size
+
+ self.layer_norm = nn.LayerNorm(self.hidden_size)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = nn.GELU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.down_proj(self.act_fn(self.up_proj(hidden_states)))
+ return hidden_states
+
+
+class EdgeTamVideoPerceiverAttention(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.perceiver_resampler_hidden_size
+ self.num_attention_heads = config.perceiver_resampler_num_attention_heads
+ self.head_dim = config.perceiver_resampler_attention_head_dim
+ self.attention_dropout = config.perceiver_resampler_attention_dropout
+
+ self.inner_dim = self.head_dim * self.num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+ self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ positional_encoding: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ # Project queries, keys, and values
+ query = self.q_proj(query)
+ key = self.k_proj(key)
+ value = self.v_proj(value)
+
+ # Reshape for multi-head attention
+ batch_size, seq_len_q = query.shape[:2]
+ query = query.view(batch_size, seq_len_q, self.num_attention_heads, self.head_dim).transpose(1, 2)
+ seq_len_kv = key.shape[1]
+ key = key.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2)
+ value = value.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2)
+
+ # Add positional encoding if provided
+ if positional_encoding is not None:
+ pos_encoding = positional_encoding.view(
+ batch_size, seq_len_kv, self.num_attention_heads, self.head_dim
+ ).transpose(1, 2)
+ key = key + pos_encoding
+ value = value + pos_encoding
+
+ # Apply attention
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, _ = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ # Reshape output
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.inner_dim)
+ return self.o_proj(attn_output)
+
+
+class EdgeTamVideoPerceiverEncoderLayer(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+
+ self.cross_attention = EdgeTamVideoPerceiverAttention(config)
+ self.mlp = EdgeTamVideoPerceiverMLP(config)
+ self.dropout = nn.Dropout(config.perceiver_resampler_hidden_dropout)
+
+ self.self_attention = EdgeTamVideoPerceiverAttention(config)
+ self.self_mlp = EdgeTamVideoPerceiverMLP(config)
+
+ # Layer norms moved from attention classes to here
+ self.layer_norm_input = nn.LayerNorm(config.perceiver_resampler_hidden_size)
+ self.layer_norm_latents = nn.LayerNorm(config.perceiver_resampler_hidden_size)
+ self.layer_norm_self = nn.LayerNorm(config.perceiver_resampler_hidden_size)
+
+ def forward(
+ self,
+ latents: torch.Tensor,
+ input_features: torch.Tensor,
+ positional_encoding: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Cross attention with layer norms
+ normalized_latents = self.layer_norm_latents(latents)
+ normalized_input = self.layer_norm_input(input_features)
+ cross_attention_output = self.cross_attention(
+ query=normalized_latents,
+ key=normalized_input,
+ value=normalized_input,
+ positional_encoding=positional_encoding,
+ )
+ latents = latents + self.dropout(cross_attention_output)
+
+ mlp_output = self.mlp(latents)
+ latents = latents + mlp_output
+
+ # Self attention with layer norm
+ normalized_latents_self = self.layer_norm_self(latents)
+ self_attention_output = self.self_attention(
+ query=normalized_latents_self, key=normalized_latents_self, value=normalized_latents_self
+ )
+ latents = latents + self_attention_output
+
+ self_mlp_output = self.self_mlp(latents)
+ latents = latents + self_mlp_output
+
+ return latents
+
+
+def window_partition(hidden_state, window_size):
+ """
+ Partition into non-overlapping windows with padding if needed.
+
+ Args:
+ hidden_state (`torch.Tensor`):
+ Input tokens with [batch_size, height, width, num_channels].
+ window_size (`int`):
+ Window size.
+
+ Returns:
+ `tuple(torch.FloatTensor)` comprising various elements:
+ - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
+ - (padded_height, padded_width): padded height and width before partition
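+
+ Example (illustrative sketch; the tensor values and sizes are arbitrary):
+
+ ```python
+ >>> hidden_state = torch.randn(1, 14, 14, 32)
+ >>> windows, (padded_height, padded_width) = window_partition(hidden_state, window_size=8)
+ >>> windows.shape, (padded_height, padded_width)
+ (torch.Size([4, 8, 8, 32]), (16, 16))
+ ```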
+ """
+ batch_size, height, width, num_channels = hidden_state.shape
+
+ pad_height = (window_size - height % window_size) % window_size
+ pad_width = (window_size - width % window_size) % window_size
+
+ # Noop in case pad_width == 0 and pad_height == 0.
+ hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
+
+ padded_height, padded_width = height + pad_height, width + pad_width
+
+ hidden_state = hidden_state.view(
+ batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
+ )
+ windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+ return windows, (padded_height, padded_width)
+
+
+class EdgeTamVideoPerceiverResampler(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.perceiver_resampler_hidden_size
+ self.num_latents_1d = config.perceiver_resampler_num_latents
+ self.num_latents_2d = config.perceiver_resampler_num_latents_2d
+ self.num_layers = config.perceiver_resampler_num_layers
+
+ if self.num_latents_1d > 0:
+ self.latents_1d = nn.Parameter(torch.randn(self.num_latents_1d, self.hidden_size))
+ if self.num_latents_2d > 0:
+ self.latents_2d = nn.Parameter(torch.randn(self.num_latents_2d, self.hidden_size))
+
+ self.positional_encoding = EdgeTamVideoPositionEmbeddingSine(
+ num_pos_feats=self.hidden_size // 2, normalize=True
+ )
+
+ self.layers = nn.ModuleList([EdgeTamVideoPerceiverEncoderLayer(config) for _ in range(self.num_layers)])
+
+ self.layer_norm = nn.LayerNorm(self.hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ positional_encoding: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ output_latents = []
+ output_positional_encodings = []
+
+ if self.num_latents_1d > 0:
+ latents_1d, pos_1d = self._forward_1d(hidden_states, positional_encoding)
+ output_latents.append(latents_1d)
+ output_positional_encodings.append(pos_1d)
+
+ if self.num_latents_2d > 0:
+ latents_2d, pos_2d = self._forward_2d(hidden_states)
+ output_latents.append(latents_2d)
+ output_positional_encodings.append(pos_2d)
+
+ combined_latents = torch.cat(output_latents, dim=1)
+
+ combined_positional_encoding = None
+ if positional_encoding is not None and output_positional_encodings:
+ combined_positional_encoding = torch.cat(output_positional_encodings, dim=1)
+
+ return combined_latents, combined_positional_encoding
+
+ def _forward_1d(
+ self,
+ hidden_states: torch.Tensor,
+ positional_encoding: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ batch_size = hidden_states.shape[0]
+
+ latents = self.latents_1d.unsqueeze(0).expand(batch_size, -1, -1)
+ flattened_features = hidden_states.permute(0, 2, 3, 1).flatten(1, 2)
+
+ positional_features = None
+ if positional_encoding is not None:
+ positional_features = positional_encoding.permute(0, 2, 3, 1).flatten(1, 2)
+
+ for layer in self.layers:
+ latents = layer(latents, flattened_features, positional_features)
+
+ latents = self.layer_norm(latents)
+
+ output_positional_encoding = None
+ if positional_encoding is not None:
+ output_positional_encoding = torch.zeros_like(latents)
+
+ return latents, output_positional_encoding
+
+ def _forward_2d(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, channels, height, width = hidden_states.shape
+
+ latents_2d = self.latents_2d.unsqueeze(0).expand(batch_size, -1, -1).view(-1, 1, channels)
+
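+ # Each 2D latent attends to one window: num_latents_2d is assumed to be a perfect square so that the
+ # feature map tiles into an n x n grid of windows.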
+ num_windows_per_dim = int(math.sqrt(self.num_latents_2d))
+ window_size = height // num_windows_per_dim
+
+ windowed_input = hidden_states.permute(0, 2, 3, 1)
+ windowed_features, _ = window_partition(windowed_input, window_size)
+ windowed_features = windowed_features.flatten(1, 2)
+
+ for layer in self.layers:
+ latents_2d = layer(latents_2d, windowed_features, positional_encoding=None)
+
+ latents_2d = latents_2d.view(batch_size, num_windows_per_dim, num_windows_per_dim, channels).permute(
+ 0, 3, 1, 2
+ )
+
+ positional_encoding_2d = self.positional_encoding(latents_2d.shape, latents_2d.device, latents_2d.dtype).to(
+ dtype=hidden_states.dtype
+ )
+ positional_encoding_2d = positional_encoding_2d.permute(0, 2, 3, 1).flatten(1, 2)
+
+ latents_2d = latents_2d.permute(0, 2, 3, 1).flatten(1, 2)
+ latents_2d = self.layer_norm(latents_2d)
+
+ return latents_2d, positional_encoding_2d
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the EdgeTamVideo model's output.")
+class EdgeTamVideoImageSegmentationOutput(ModelOutput):
+ r"""
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
+ The Intersection over Union (IoU) scores of the predicted masks.
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
+ The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
+ by the processor to be brought to the original image size.
+ object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
+ Logits for the object score, indicating if an object is present.
+ image_embeddings (`tuple(torch.FloatTensor)`):
+ The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
+ tensor has shape `(batch_size, channels, height, width)`.
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
+ Hidden-states of the vision model at the output of each stage.
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the vision model.
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+ Attentions weights of the mask decoder.
+ high_res_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, image_size, image_size)`, *optional*):
+ The predicted masks, upscaled to the original image size. Only used for EdgeTamVideoModel.
+ object_pointer (`torch.FloatTensor` of shape `(batch_size, point_batch_size, hidden_size)`, *optional*):
+ A tensor representing the object pointer, used for tracking in videos. Only used for EdgeTamVideoModel.
+ """
+
+ iou_scores: Optional[torch.FloatTensor] = None
+ pred_masks: Optional[torch.FloatTensor] = None
+ object_score_logits: Optional[torch.FloatTensor] = None
+ image_embeddings: Optional[tuple[torch.FloatTensor, ...]] = None
+ vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+ mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+ high_res_masks: Optional[torch.FloatTensor] = None
+ object_pointer: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(custom_intro="Base class for the Sam2 model's output.")
+class EdgeTamVideoSegmentationOutput(ModelOutput):
+ r"""
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
+ The predicted masks stored at the model's resolution.
+ frame_idx (`int`):
+ The frame index of the video.
+ """
+
+ pred_masks: Optional[torch.FloatTensor] = None
+ frame_idx: Optional[int] = None
+
+
+class EdgeTamVideoPositionalEmbedding(nn.Module):
+ def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
+ super().__init__()
+ self.scale = config.scale
+ positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
+ self.register_buffer("positional_embedding", positional_embedding)
+
+ def forward(self, input_coords, input_shape=None):
+ """Positionally encode points that are normalized to [0,1]."""
+ coordinates = input_coords.clone()
+
+ if input_shape is not None:
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
+ coordinates = coordinates.to(torch.float32)
+
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coordinates = 2 * coordinates - 1
+ coordinates = coordinates.to(self.positional_embedding.dtype)
+ coordinates = coordinates @ self.positional_embedding
+ coordinates = 2 * np.pi * coordinates
+ # outputs d_1 x ... x d_n x channel shape
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
+
+
+class EdgeTamVideoMaskEmbedding(nn.Module):
+ def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
+ super().__init__()
+ self.mask_input_channels = config.mask_input_channels // 4
+ self.activation = ACT2FN[config.hidden_act]
+ self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
+ self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
+ self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
+ self.layer_norm1 = EdgeTamVideoLayerNorm(
+ self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+ self.layer_norm2 = EdgeTamVideoLayerNorm(
+ self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
+ )
+
+ def forward(self, masks):
+ hidden_states = self.conv1(masks)
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ dense_embeddings = self.conv3(hidden_states)
+ return dense_embeddings
+
+
+class EdgeTamVideoPromptEncoder(nn.Module):
+ def __init__(self, config: EdgeTamVideoPromptEncoderConfig):
+ super().__init__()
+ self.shared_embedding = EdgeTamVideoPositionalEmbedding(config)
+ self.mask_embed = EdgeTamVideoMaskEmbedding(config)
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
+
+ self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
+ self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
+ self.input_image_size = config.image_size
+
+ self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
+ self.hidden_size = config.hidden_size
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
+
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
+ labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
+ input_shape = (self.input_image_size, self.input_image_size)
+ point_embedding = self.shared_embedding(points, input_shape)
+
+ # torch.where and expanding the labels tensor is required by the ONNX export
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
+
+ # This is required for the ONNX export. The dtype, device need to be explicitly
+ # specified as otherwise torch.onnx.export interprets as double
+ point_embedding = torch.where(
+ labels[..., None] != -10,
+ point_embedding,
+ torch.zeros_like(point_embedding),
+ )
+
+ # Add point embeddings for labels >= 0
+ point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
+
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5  # Shift to center of pixel (avoid modifying the input in place)
+ coords = boxes.view(*boxes.shape[:2], 2, 2)
+ # add padding point for consistency with the original implementation
+ coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
+ corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
+ corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
+ corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
+ corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
+ return corner_embedding
+
+ def forward(
+ self,
+ input_points: Optional[torch.Tensor],
+ input_labels: Optional[torch.Tensor],
+ input_boxes: Optional[torch.Tensor],
+ input_masks: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense embeddings.
+
+ Args:
+ input_points (`torch.Tensor`, *optional*):
+ Point coordinates to embed.
+ input_labels (`torch.Tensor`, *optional*):
+ Labels of the points to embed.
+ input_boxes (`torch.Tensor`, *optional*):
+ Boxes to embed.
+ input_masks (`torch.Tensor`, *optional*):
+ Masks to embed.
+ """
+ sparse_embeddings = None
+ batch_size = 1
+ if input_points is not None:
+ batch_size = input_points.shape[0]
+ if input_labels is None:
+ raise ValueError("If points are provided, labels must also be provided.")
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
+ sparse_embeddings = point_embeddings
+ if input_boxes is not None:
+ batch_size = input_boxes.shape[0]
+ box_embeddings = self._embed_boxes(input_boxes)
+ if sparse_embeddings is None:
+ sparse_embeddings = box_embeddings
+ else:
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
+ if input_masks is not None:
+ dense_embeddings = self.mask_embed(input_masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
+
+
+class EdgeTamVideoTwoWayTransformer(nn.Module):
+ def __init__(self, config: EdgeTamVideoMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_hidden_layers = config.num_hidden_layers
+ self.layers = nn.ModuleList()
+
+ for i in range(self.num_hidden_layers):
+ self.layers.append(EdgeTamVideoTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
+
+ self.final_attn_token_to_image = EdgeTamVideoAttention(config)
+ self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ point_embeddings: Tensor,
+ image_embeddings: Tensor,
+ image_positional_embeddings: Tensor,
+ attention_similarity: Tensor,
+ target_embedding=None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, BaseModelOutput]:
+ if image_embeddings is None:
+ raise ValueError("You have to specify an image_embedding")
+
+ image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+ image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
+
+ # Prepare queries
+ queries = point_embeddings
+ keys = image_embeddings
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ if target_embedding is not None:
+ queries += target_embedding
+
+ queries, keys, _ = layer(
+ queries=queries,
+ keys=keys,
+ query_point_embedding=point_embeddings,
+ key_point_embedding=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ **kwargs,
+ )
+ # Apply the final attention layer from the points to the image
+ query = queries + point_embeddings
+ key = keys + image_positional_embeddings
+
+ attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
+
+ queries = queries + attn_out
+ queries = self.layer_norm_final_attn(queries)
+ return queries, keys
+
+
+class EdgeTamVideoMaskDecoder(nn.Module):
+ def __init__(self, config: EdgeTamVideoMaskDecoderConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.num_multimask_outputs = config.num_multimask_outputs
+ self.num_mask_tokens = config.num_multimask_outputs + 1
+
+ self.iou_token = nn.Embedding(1, self.hidden_size)
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
+
+ self.transformer = EdgeTamVideoTwoWayTransformer(config)
+
+ # should we create a new class for this?
+ self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
+ self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
+ self.upscale_layer_norm = EdgeTamVideoLayerNorm(self.hidden_size // 4, data_format="channels_first")
+ self.activation = nn.GELU()
+
+ mlps_list = []
+ for _ in range(self.num_mask_tokens):
+ mlps_list += [EdgeTamVideoFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
+ self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
+ self.iou_prediction_head = EdgeTamVideoFeedForward(
+ self.hidden_size,
+ config.iou_head_hidden_dim,
+ self.num_mask_tokens,
+ config.iou_head_depth,
+ sigmoid_output=True,
+ )
+
+ self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
+ self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
+
+ self.obj_score_token = nn.Embedding(1, self.hidden_size)
+ self.pred_obj_score_head = EdgeTamVideoFeedForward(self.hidden_size, self.hidden_size, 1, 3)
+
+ self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_positional_embeddings: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ high_resolution_features: list[torch.Tensor],
+ attention_similarity: Optional[torch.Tensor] = None,
+ target_embedding: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Args:
+ image_embeddings (`torch.Tensor`):
+ The embeddings from the image encoder.
+ image_positional_embeddings (`torch.Tensor`):
+ Positional encoding with the shape of image_embeddings.
+ sparse_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the points and boxes.
+ dense_prompt_embeddings (`torch.Tensor`):
+ The embeddings of the mask inputs.
+ multimask_output (`bool`):
+ Whether to return multiple masks or a single mask.
+ high_resolution_features (`list[torch.Tensor]`, *optional*):
+ The high-resolution features from the vision encoder.
+ attention_similarity (`torch.Tensor`, *optional*):
+ The attention similarity tensor.
+ target_embedding (`torch.Tensor`, *optional*):
+ The target embedding.
+ """
+ batch_size, num_channels, height, width = image_embeddings.shape
+ point_batch_size = sparse_prompt_embeddings.shape[1]
+ # Concatenate output tokens
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+ if sparse_prompt_embeddings.shape[0] != 0:
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+ else:
+ tokens = output_tokens
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+ # Expand per-image data in batch direction to be per-mask
+ image_embeddings = image_embeddings + dense_prompt_embeddings
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
+ # Run the transformer
+ point_embeddings, image_embeddings = self.transformer(
+ point_embeddings=point_embeddings,
+ image_embeddings=image_embeddings,
+ image_positional_embeddings=image_positional_embeddings,
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
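+ # Token order matches `output_tokens` above: index 0 is the object score token, index 1 the IoU token,
+ # and the next num_mask_tokens entries are the mask tokens.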
+ iou_token_out = point_embeddings[:, :, 1, :]
+ mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ image_embeddings = image_embeddings.transpose(2, 3).view(
+ batch_size * point_batch_size, num_channels, height, width
+ )
+
+ feat_s0, feat_s1 = high_resolution_features
+ feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
+ feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
+ upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: list[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ current_mlp = self.output_hypernetworks_mlps[i]
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+ hyper_in = torch.stack(hyper_in_list, dim=2)
+
+ _, num_channels, height, width = upscaled_embedding.shape
+ upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
+ masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ mask_slice = slice(0, 1)
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, :, mask_slice, :, :]
+ iou_pred = iou_pred[:, :, mask_slice]
+
+ sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
+
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """
+ Compute stability scores of the mask logits based on the IoU between upper and
+ lower thresholds.
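+
+ As a toy illustration (values chosen arbitrarily): with a stability delta of 1.0 and flattened logits
+ [2.0, 5.0, -2.0, 0.5], two pixels exceed +1.0 and three exceed -1.0, giving a stability score of 2/3.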
+ """
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ When outputting a single mask, if the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, we instead select from
+ multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+ IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, :, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, :, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P]
+ best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ best_scores_inds_expanded = best_scores_inds_expanded.expand(
+ -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
+ )
+ best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W]
+ best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1]
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, :, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
+
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """
+ Get 1D sine positional embedding as in the original Transformer paper.
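+
+ Example (illustrative; the positional indices are arbitrary):
+
+ ```python
+ >>> pos_embed = get_1d_sine_pe(torch.arange(4), dim=64)
+ >>> pos_embed.shape
+ torch.Size([4, 64])
+ ```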
+ """
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+
+
+@auto_docstring
+class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel):
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
+ # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
+ _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}
+ _keys_to_ignore_on_load_unexpected = []
+
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__(config)
+ self.shared_image_embedding = EdgeTamVideoPositionalEmbedding(config.prompt_encoder_config)
+ self.vision_encoder = AutoModel.from_config(config.vision_config)
+ self.prompt_encoder = EdgeTamVideoPromptEncoder(config.prompt_encoder_config)
+ # The module using it is not a PreTrainedModel subclass so we need this
+ config.mask_decoder_config._attn_implementation = config._attn_implementation
+ self.mask_decoder = EdgeTamVideoMaskDecoder(config.mask_decoder_config)
+
+ self.num_feature_levels = config.vision_config.num_feature_levels
+ self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
+ # a single token to indicate no memory embedding from previous frames
+ self.hidden_dim = config.vision_config.fpn_hidden_size
+ self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ self.config = config
+ # For video sequence inference
+ self.image_size = config.image_size
+ self.memory_attention = EdgeTamVideoMemoryAttention(config)
+ self.memory_encoder = EdgeTamVideoMemoryEncoder(config)
+ self.no_memory_positional_encoding = torch.nn.Parameter(
+ torch.zeros(1, 1, config.vision_config.fpn_hidden_size)
+ )
+ self.mem_dim = config.memory_encoder_output_channels
+ self.num_maskmem = config.num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.memory_temporal_positional_encoding = torch.nn.Parameter(
+ torch.zeros(self.num_maskmem, 1, 1, self.mem_dim)
+ )
+
+ self.no_object_pointer = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ # a feedforward layer on SAM output tokens to turn them into object pointers
+ self.object_pointer_proj = EdgeTamVideoFeedForward(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
+
+ if self.config.enable_temporal_pos_encoding_for_object_pointers:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.temporal_positional_encoding_projection_layer = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.temporal_positional_encoding_projection_layer = torch.nn.Identity()
+
+ self.occlusion_spatial_embedding_parameter = None # compatibility with Sam2
+ if config.enable_occlusion_spatial_embedding:
+ self.occlusion_spatial_embedding_parameter = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+ self.spatial_perceiver = EdgeTamVideoPerceiverResampler(config)
+
+ self.post_init()
+
+ def _tie_weights(self):
+ self.prompt_encoder.shared_embedding.positional_embedding.data = (
+ self.shared_image_embedding.positional_embedding.data
+ )
+
+ def get_input_embeddings(self):
+ return self.vision_encoder.get_input_embeddings()
+
+ def get_image_wide_positional_embeddings(self) -> torch.Tensor:
+ size = self.prompt_encoder.image_embedding_size
+ target_device = self.shared_image_embedding.positional_embedding.device
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
+ grid = torch.ones(size, device=target_device, dtype=target_dtype)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / size[0]
+ x_embed = x_embed / size[1]
+
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # 1 x channel x height x width
+
+ @torch.no_grad()
+ def get_image_embeddings(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> list[torch.Tensor]:
+ r"""
+ Returns the image embeddings by passing the pixel values through the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Input pixel values
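+
+ Example (sketch; the checkpoint id is a placeholder for a real EdgeTAM video checkpoint and the pixel
+ values are random):
+
+ ```python
+ >>> import torch
+ >>> from transformers import EdgeTamVideoModel
+
+ >>> model = EdgeTamVideoModel.from_pretrained("<edgetam-video-checkpoint>")
+ >>> pixel_values = torch.randn(1, 3, model.config.image_size, model.config.image_size)
+ >>> image_embeddings = model.get_image_embeddings(pixel_values)  # one feature map per feature level
+ ```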
+ """
+ batch_size = pixel_values.shape[0]
+ feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ return image_embeddings
+
+ @torch.no_grad()
+ def get_prompt_embeddings(
+ self,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
+
+ Args:
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
+ point. The model will output `point_batch_size` times 3 masks in total.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
+ processor, or can be fed by the user.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
+ processor. Users can also pass the input boxes manually.
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
+ Optional input masks for the prompt encoder.
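+
+ Example (sketch; the checkpoint id is a placeholder and the point coordinates are arbitrary):
+
+ ```python
+ >>> import torch
+ >>> from transformers import EdgeTamVideoModel
+
+ >>> model = EdgeTamVideoModel.from_pretrained("<edgetam-video-checkpoint>")
+ >>> input_points = torch.tensor([[[[512.0, 512.0]]]])  # (batch_size, point_batch_size, num_points, 2)
+ >>> input_labels = torch.tensor([[[1]]])  # 1 marks a foreground point
+ >>> sparse_embeddings, dense_embeddings = model.get_prompt_embeddings(input_points=input_points, input_labels=input_labels)
+ ```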
+ """
+ prompt_output = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ return prompt_output
+
+ @torch.inference_mode()
+ @auto_docstring(custom_intro="Propagate the objects through a streamed video frame.")
+ def forward(
+ self,
+ inference_session: EdgeTamVideoInferenceSession,
+ frame_idx: Optional[int] = None,
+ frame: Optional[torch.Tensor] = None,
+ reverse: bool = False,
+ ) -> EdgeTamVideoSegmentationOutput:
+ r"""
+ inference_session (`EdgeTamVideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`, *optional*):
+ The index of the frame on which to run inference. No need to provide when inferring
+ on a new streamed frame.
+ frame (`torch.Tensor`, *optional*):
+ The frame to process. Provide when streaming.
+ reverse (`bool`, *optional*, defaults to `False`):
+ Whether to propagate in reverse.
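+
+ Example (sketch of streaming inference; assumes `model` and `inference_session` were created beforehand,
+ that prompts for at least one object were already added to the session, and that `video_frames` is an
+ iterable of frame tensors):
+
+ ```python
+ >>> for frame in video_frames:
+ ...     output = model(inference_session=inference_session, frame=frame)
+ ...     masks = output.pred_masks  # low-resolution mask logits, one entry per tracked object
+ ```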
+ """
+ if frame is not None:
+ frame_idx = inference_session.add_new_frame(frame, frame_idx)
+
+ if frame is not None and inference_session.get_obj_num() == 0:
+ raise ValueError("No objects are provided for tracking; please add inputs first.")
+
+ num_objects = inference_session.get_obj_num()
+ pred_masks_per_obj = [None] * num_objects
+ # Note: We avoid batched inference here because per-object inputs (clicks/masks)
+ # can differ across objects.
+ for obj_idx in range(num_objects):
+ obj_id = inference_session.obj_idx_to_id(obj_idx)
+ has_new_inputs = obj_id in inference_session.obj_with_new_inputs
+ has_cond_output = frame_idx in inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ # If this object has no new inputs and this frame already has a
+ # conditioning output, reuse the cached masks instead of recomputing.
+ if (not has_new_inputs) and has_cond_output:
+ pred_masks = inference_session.get_output(obj_idx, frame_idx, "pred_masks", is_conditioning_frame=True)
+ is_init_cond_frame = True
+ else:
+ # Defaults when there are no new inputs
+ is_init_cond_frame = False
+ point_inputs = None
+ mask_inputs = None
+
+ if has_new_inputs:
+ is_init_cond_frame = frame_idx not in inference_session.frames_tracked_per_obj[obj_idx]
+ if is_init_cond_frame:
+ reverse = False
+ point_inputs = inference_session.point_inputs_per_obj[obj_idx].get(frame_idx, None)
+ mask_inputs = inference_session.mask_inputs_per_obj[obj_idx].get(frame_idx, None)
+ if point_inputs is not None or mask_inputs is not None:
+ inference_session.obj_with_new_inputs.remove(obj_id)
+
+ current_out = self._run_single_frame_inference(
+ inference_session=inference_session,
+ obj_idx=obj_idx,
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ reverse=reverse,
+ run_mem_encoder=True,
+ streaming=frame is not None,
+ )
+ inference_session.store_output(
+ obj_idx, frame_idx, output_value=current_out, is_conditioning_frame=is_init_cond_frame
+ )
+ pred_masks = current_out["pred_masks"]
+
+ pred_masks_per_obj[obj_idx] = pred_masks
+ if not is_init_cond_frame:
+ # only for tracked frames, not for initial conditioning frames
+ inference_session.frames_tracked_per_obj[obj_idx][frame_idx] = {"reverse": reverse}
+
+ # Concatenate the per-object mask scores (kept on the inference device to avoid any CPU conversion in between)
+ if len(pred_masks_per_obj) > 1:
+ all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
+ else:
+ all_pred_masks = pred_masks_per_obj[0]
+
+ return EdgeTamVideoSegmentationOutput(pred_masks=all_pred_masks, frame_idx=frame_idx)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple[
+ list[torch.Tensor],
+ list[torch.Tensor],
+ Optional[tuple[torch.FloatTensor, ...]],
+ Optional[tuple[torch.FloatTensor, ...]],
+ ]:
+ r"""
+ Extract and preprocess image features using the vision encoder.
+
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Input pixel values of shape `(batch_size, num_channels, height, width)`.
+
+ Returns:
+ `tuple`: A tuple containing:
+ - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels.
+ - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level.
+ - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder.
+ - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder.
+ """
+ vision_outputs: EdgeTamVideoVisionEncoderOutput = self.vision_encoder(
+ pixel_values,
+ **kwargs,
+ )
+
+ feature_maps = vision_outputs.fpn_hidden_states
+ feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
+
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ feature_maps = list(feature_maps)
+ feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
+ feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
+
+ # flatten NxCxHxW to HWxNxC
+ feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
+ feature_maps_position_embeddings = [
+ feature_map_position_embedding.flatten(2).permute(2, 0, 1)
+ for feature_map_position_embedding in feature_maps_position_embeddings
+ ]
+
+ return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions
+
+ def _prepare_vision_features(
+ self,
+ inference_session: EdgeTamVideoInferenceSession,
+ frame_idx: int,
+ batch_size: int,
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ """Prepare vision features for a frame."""
+
+ # Check if features are cached
+ if cached_features := inference_session.cache.get_vision_features(frame_idx):
+ vision_feats = cached_features["vision_feats"]
+ vision_pos_embeds = cached_features["vision_pos_embeds"]
+ else:
+ # Compute features using image encoder
+ image_batch = inference_session.get_frame(frame_idx).unsqueeze(0) # Add batch dimension
+ vision_feats, vision_pos_embeds, _, _ = self.get_image_features(image_batch)
+ # Cache features
+ inference_session.cache.cache_vision_features(
+ frame_idx, {"vision_feats": vision_feats, "vision_pos_embeds": vision_pos_embeds}
+ )
+
+ # Expand to batch size if needed
+ if batch_size > 1:
+ vision_feats = vision_feats.expand(batch_size, -1, -1, -1)
+ vision_pos_embeds = [pe.expand(batch_size, -1, -1, -1) for pe in vision_pos_embeds]
+
+ return vision_feats, vision_pos_embeds
+
+ def _single_frame_forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ input_points: Optional[torch.FloatTensor] = None,
+ input_labels: Optional[torch.LongTensor] = None,
+ input_boxes: Optional[torch.FloatTensor] = None,
+ input_masks: Optional[torch.LongTensor] = None,
+ image_embeddings: Optional[torch.FloatTensor] = None,
+ multimask_output: bool = True,
+ attention_similarity: Optional[torch.FloatTensor] = None,
+ target_embedding: Optional[torch.FloatTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> EdgeTamVideoImageSegmentationOutput:
+ """
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points, 2)`):
+ Input 2D spatial points; these are used by the prompt encoder to encode the prompt and generally yield
+ much better results. The points can be obtained by passing a list of lists of lists to the processor, which will
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
+ second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
+ multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
+ coordinates of the point. If a different number of points is passed either for each image, or for each
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
+ computation of the embedding will be skipped for these points using the labels.
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
+ official implementation, there are 3 types of labels
+
+ - `1`: the point is a point that contains the object of interest
+ - `0`: the point is a point that does not contain the object of interest
+ - `-1`: the point corresponds to the background
+
+ We added the label:
+
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
+
+ The padding labels should be automatically done by the processor.
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
+ Input boxes for the points; these are used by the prompt encoder to encode the prompt and generally yield
+ much better generated masks. The boxes can be obtained by passing a list of lists of lists to the processor,
+ that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
+ size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
+ In the order (`x1`, `y1`, `x2`, `y2`):
+
+ - `x1`: the x coordinate of the top left point of the input box
+ - `y1`: the y coordinate of the top left point of the input box
+ - `x2`: the x coordinate of the bottom right point of the input box
+ - `y2`: the y coordinate of the bottom right point of the input box
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
+ The model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
+ generate a corresponding embedding that will later be fed to the mask decoder. These masks need to be
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
+ Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
+ method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
+ multimask_output (`bool`, *optional*):
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
+ bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
+ "best" mask, by specifying `multimask_output=False`.
+ attention_similarity (`torch.FloatTensor`, *optional*):
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+ model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ target_embedding (`torch.FloatTensor`, *optional*):
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+ the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
+ """
+ if not ((pixel_values is None) ^ (image_embeddings is None)):
+ raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
+ if input_points is not None and input_boxes is not None:
+ if input_points.shape[1] != input_boxes.shape[1]:
+ raise ValueError(
+ f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
+ )
+ elif input_points is not None:
+ num_objects = input_points.shape[1]
+ elif input_boxes is not None:
+ num_objects = input_boxes.shape[1]
+ elif input_masks is not None:
+ num_objects = input_masks.shape[1]
+ else:
+ num_objects = 1
+
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
+ # repeat with batch size
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
+
+ vision_attentions = None
+ vision_hidden_states = None
+
+ if pixel_values is not None:
+ feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
+ pixel_values,
+ **kwargs,
+ )
+
+ # add no memory embedding to the last feature map
+ feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
+
+ # reshape feature maps to the same shape as the backbone feature sizes
+ image_embeddings = [
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+ for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
+ ]
+
+ if input_points is not None and input_labels is None:
+ input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
+
+ if input_points is None and input_boxes is None:
+ # If no points are provided, pad with an empty point (with label -1)
+ input_points = torch.zeros(
+ batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
+ )
+ input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
+
+ if input_masks is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
+ input_masks = F.interpolate(
+ input_masks.float(),
+ size=self.prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(input_masks.dtype)
+
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ input_points=input_points,
+ input_labels=input_labels,
+ input_boxes=input_boxes,
+ input_masks=input_masks,
+ )
+ low_res_multimasks, iou_scores, sam_output_tokens, object_score_logits = self.mask_decoder(
+ image_embeddings=image_embeddings[-1],
+ image_positional_embeddings=image_positional_embeddings,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ high_resolution_features=image_embeddings[:-1],
+ attention_similarity=attention_similarity,
+ target_embedding=target_embedding,
+ **kwargs,
+ )
+
+ is_obj_appearing = object_score_logits > 0
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ high_res_multimasks = (
+ F.interpolate(
+ low_res_multimasks.squeeze(1).float(),
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ .unsqueeze(1)
+ .to(low_res_multimasks.dtype)
+ )
+ sam_output_token = sam_output_tokens[:, :, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(iou_scores, dim=-1)
+ batch_inds = torch.arange(batch_size, device=high_res_multimasks.device)
+ object_batch_inds = torch.arange(num_objects, device=high_res_multimasks.device)
+ low_res_masks = low_res_multimasks[batch_inds, object_batch_inds, best_iou_inds]
+ high_res_masks = high_res_multimasks[batch_inds, object_batch_inds, best_iou_inds]
+ if sam_output_tokens.size(2) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, object_batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks[:, :, 0], high_res_multimasks[:, :, 0]
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ object_pointer = self.object_pointer_proj(sam_output_token)
+ lambda_is_obj_appearing = is_obj_appearing.to(object_pointer.dtype)
+
+ object_pointer = lambda_is_obj_appearing * object_pointer
+ object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer
+
+ return EdgeTamVideoImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_masks,
+ high_res_masks=high_res_masks,
+ object_pointer=object_pointer,
+ object_score_logits=object_score_logits,
+ image_embeddings=image_embeddings,
+ vision_hidden_states=vision_hidden_states,
+ vision_attentions=vision_attentions,
+ )
+
+ def _use_mask_as_output(
+ self,
+ backbone_features: torch.Tensor,
+ high_res_features: list[torch.Tensor],
+ mask_inputs: torch.Tensor,
+ ) -> EdgeTamVideoImageSegmentationOutput:
+ """
+ Directly turn binary `mask_inputs` into output mask logits without using SAM.
+ (same input and output shapes as in forward above).
+ """
+ # Use -10/+20 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
+ mask_inputs_float = mask_inputs.to(backbone_features[0].dtype)
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks.float(),
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ ).to(backbone_features[0].dtype)
+ # a dummy IoU prediction of all 1's under mask input
+ iou_scores = mask_inputs.new_ones(mask_inputs.size(0), 1).to(backbone_features[0].dtype)
+ # produce an object pointer using the SAM decoder from the mask input
+ object_pointer = self._single_frame_forward(
+ input_masks=self.mask_downsample(mask_inputs_float.to(backbone_features[0].dtype)),
+ image_embeddings=high_res_features + [backbone_features],
+ ).object_pointer
+ # In this method, we treat mask_input as the output, e.g. using it directly to create the spatial memory;
+ # below, we follow the same principle and use mask_input to decide whether the object appears or not,
+ # instead of relying on the object scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.to(backbone_features[0].dtype)
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ object_pointer = lambda_is_obj_appearing * object_pointer
+ object_pointer = object_pointer + (1 - lambda_is_obj_appearing) * self.no_object_pointer
+ return EdgeTamVideoImageSegmentationOutput(
+ iou_scores=iou_scores,
+ pred_masks=low_res_masks,
+ high_res_masks=high_res_masks,
+ object_pointer=object_pointer,
+ object_score_logits=object_score_logits,
+ image_embeddings=high_res_features + [backbone_features],
+ )
+
+ def _gather_memory_frame_outputs(
+ self,
+ inference_session: EdgeTamVideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ track_in_reverse_time: bool = False,
+ ) -> list[tuple[int, dict]]:
+ """
+ Get memory frames from conditioning and non-conditioning outputs.
+
+ Returns:
+ List of (relative_temporal_offset, output_data) tuples.
+ """
+ temporal_positions_and_previous_outputs = []
+
+ # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints)
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ if not conditioning_outputs:
+ raise ValueError(
+ "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame"
+ )
+
+ # Store (temporal_position, output_data) tuples
+ temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()]
+
+ # Add non-conditioning memory frames (up to self.num_maskmem - 1)
+ # These are typically frames tracked by the model without direct user input.
+ # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity.
+ for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1):
+ # relative_temporal_offset: how many frames before (or after if reversing) the current frame
+ if not track_in_reverse_time:
+ previous_frame_idx = frame_idx - relative_temporal_offset
+ else:
+ previous_frame_idx = frame_idx + relative_temporal_offset
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ previous_frame_idx, None
+ )
+
+ temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data))
+
+ return temporal_positions_and_previous_outputs
+
+ def _build_memory_attention_inputs(
+ self,
+ temporal_positions_and_previous_outputs: list[tuple[int, dict]],
+ device: torch.device,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+ """
+ Concatenate memory features and positional embeddings from previous frames.
+
+ Returns:
+ Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate).
+ """
+ memories_to_concatenate = []
+ memory_positional_embeddings_to_concatenate = []
+
+ for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
+ if prev_output_data is None:
+ continue # Skip if no output data for this temporal position (e.g., padding frames)
+
+ # Load memory features (potentially from CPU to GPU)
+ # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels)
+ memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
+ memories_to_concatenate.append(memory_features.permute(1, 0, 2))
+
+ # Spatial positional encoding (potentially from CPU to GPU)
+ spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
+ spatial_memory_pos_embed = spatial_memory_pos_embed.squeeze(1).permute(1, 0, 2)
+
+ # Add temporal positional encoding
+ # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
+ combined_memory_pos_embed = (
+ spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
+ )
+ memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)
+
+ return memories_to_concatenate, memory_positional_embeddings_to_concatenate
+
+ def _get_object_pointers(
+ self,
+ inference_session: EdgeTamVideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ num_total_frames: int,
+ device: torch.device,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> tuple[list[int], list[torch.Tensor], int]:
+ """
+ Get object pointers and their positional embeddings from past frames.
+
+ Returns:
+ Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use).
+ """
+ temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1
+
+ # Determine max object pointers to use
+ if streaming:
+ max_object_pointers_to_use = self.config.max_object_pointers_in_encoder
+ else:
+ max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder)
+
+ temporal_offsets: list[int] = []
+ pointer_tokens: list[torch.Tensor] = []
+
+ # Add object pointers from selected conditioning frames
+ # Optionally, only include pointers from past frames during evaluation
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ eligible_conditioning_outputs = conditioning_outputs
+ if not self.training:
+ eligible_conditioning_outputs = {
+ temporal_idx: out
+ for temporal_idx, out in conditioning_outputs.items()
+ if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx)
+ }
+
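+ # Conditioning-frame pointers come first, each tagged with its signed temporal distance from the current frame.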
+ for temporal_idx, out_data in eligible_conditioning_outputs.items():
+ temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier
+ temporal_offsets.append(temporal_difference)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1)
+ for t_diff_offset in range(1, max_object_pointers_to_use):
+ ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset
+ if ref_frame_idx < 0 or (
+ not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames
+ ):
+ break # Stop if frame index is out of bounds
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ ref_frame_idx, None
+ )
+ if out_data is not None:
+ temporal_offsets.append(t_diff_offset)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ return temporal_offsets, pointer_tokens, max_object_pointers_to_use
+
+ def _process_object_pointers(
+ self,
+ temporal_offsets: list[int],
+ pointer_tokens: list[torch.Tensor],
+ max_object_pointers_to_use: int,
+ batch_size: int,
+ num_channels: int,
+ device: torch.device,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Process object pointers and compute their positional embeddings.
+
+ Returns:
+ Tuple of (object_pointers, object_pointers_pos_embed).
+ """
+ if not pointer_tokens:
+ return None, None
+
+ # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels)
+ object_pointers = torch.stack(pointer_tokens, dim=0)
+
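+ # Optionally add a sine temporal positional encoding based on each pointer's normalized temporal distance from the current frame.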
+ if self.config.enable_temporal_pos_encoding_for_object_pointers:
+ max_temporal_diff = float(max_object_pointers_to_use - 1)
+ # Determine dimensionality for temporal positional encoding of pointers
+ pointer_tpos_dim = num_channels
+
+ # Normalize temporal differences before sine PE calculation
+ normalized_temporal_diffs = (
+ torch.tensor(temporal_offsets, device=device, dtype=torch.float32) / max_temporal_diff
+ )
+ sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype)
+ projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe)
+ object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim)
+ else:
+ object_pointers_pos_embed = object_pointers.new_zeros(
+ len(temporal_offsets), batch_size, self.mem_dim, dtype=object_pointers.dtype
+ )
+
+ if self.mem_dim < num_channels:
+ # If memory dimension is smaller, reshape/split pointers and repeat positional encoding
+ num_splits = num_channels // self.mem_dim
+ object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim)
+ object_pointers = object_pointers.permute(0, 2, 1, 3).flatten(
+ 0, 1
+ ) # (SeqLen_ptr*num_splits, Batch, MemDim)
+ object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0)
+
+ return object_pointers, object_pointers_pos_embed
+
+ def _prepare_memory_conditioned_features(
+ self,
+ inference_session: EdgeTamVideoInferenceSession,
+ frame_idx: int,
+ obj_idx: int,
+ is_initial_conditioning_frame: bool,
+ current_vision_features: torch.Tensor,
+ current_vision_positional_embeddings: torch.Tensor,
+ num_total_frames: int,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> torch.Tensor:
+ """
+ Fuse current frame's visual features with memory from previous frames for enhanced object tracking.
+
+ This method conditions the current frame's visual features on temporal memory from previous frames,
+ enabling consistent object tracking across video sequences. For initial conditioning frames, it uses
+ no-memory embeddings. For subsequent frames, it retrieves and integrates memory features from both
+ conditioning frames (user interactions) and non-conditioning frames (tracked results) via cross-attention.
+
+ Args:
+ inference_session (`EdgeTamVideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`):
+ Index of the current frame being processed.
+ obj_idx (`int`):
+ Index of the object being processed.
+ is_initial_conditioning_frame (`bool`):
+ Whether this is an initial conditioning frame with user inputs (True) or a subsequent
+ tracking frame (False).
+ current_vision_features (`torch.Tensor`):
+ Highest-level vision features of shape `(seq_len, batch_size, channels)`.
+ current_vision_positional_embeddings (`torch.Tensor`):
+ Positional embedding tensors corresponding to the highest-level vision features.
+ num_total_frames (`int`):
+ Total number of frames in the video sequence.
+ track_in_reverse_time (`bool`, *optional*, defaults to `False`):
+ Whether tracking is performed in reverse temporal order.
+ streaming (`bool`, *optional*, defaults to `False`):
+ Whether this is streaming inference mode.
+
+ Returns:
+ `torch.Tensor`: Memory-conditioned feature tensor of shape `(batch_size, channels, height, width)`
+ suitable for input to the SAM decoder.
+ """
+ # Get dimensions from the highest-level (lowest-resolution) feature map
+ batch_size = current_vision_features.size(1)
+ num_channels = self.hidden_dim
+ height, width = self.backbone_feature_sizes[-1]
+ device = current_vision_features.device
+
+ # If memory is disabled (e.g., for single image SAM), return current features directly.
+ if self.num_maskmem == 0:
+ # Permute (SeqLen, Batch, Channels) -> (Batch, Channels, SeqLen) then view as (Batch, Channels, Height, Width)
+ # Assuming SeqLen = Height * Width for the last feature map
+ current_feature_map = current_vision_features.permute(1, 2, 0).view(
+ batch_size, num_channels, height, width
+ )
+ return current_feature_map
+
+ # Step 1: Handle initial conditioning frames
+ if is_initial_conditioning_frame:
+ # For initial conditioning frames, no prior memory is used in this block;
+ # instead, a learnable "no memory" embedding is added to the current features.
+ # current_vision_features has shape (SeqLen, Batch, Channels)
+ conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding
+ # Reshape to (Batch, Channels, Height, Width)
+ conditioned_feature_map = conditioned_feature_map_flat.permute(1, 2, 0).view(
+ batch_size, num_channels, height, width
+ )
+ return conditioned_feature_map
+
+ # Step 2: Get memory frames and concatenate their features
+ temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs(
+ inference_session, obj_idx, frame_idx, track_in_reverse_time
+ )
+
+ memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs(
+ temporal_positions_and_previous_outputs, device
+ )
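+ # Number of spatial memory frames; used as the RoPE repeat factor for memory keys in memory attention.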
+ num_spatial_memory_tokens = len(memories_to_concatenate)
+
+ # Step 3: Get and process object pointers
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers(
+ inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming
+ )
+
+ num_object_pointer_tokens = 0
+ if pointer_tokens:
+ object_pointers, object_pointers_pos_embed = self._process_object_pointers(
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device
+ )
+
+ if object_pointers is not None:
+ memories_to_concatenate.append(object_pointers)
+ memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
+ num_object_pointer_tokens = object_pointers.shape[0]
+
+ # Step 4: Concatenate all retrieved memories and their positional embeddings
+ combined_memory = torch.cat(memories_to_concatenate, dim=0)
+ combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0)
+
+ # Step 5: Forward through the memory attention mechanism
+ conditioned_feature_map_flat = self.memory_attention(
+ current_vision_features=current_vision_features,
+ current_vision_position_embeddings=current_vision_positional_embeddings,
+ memory=combined_memory,
+ memory_posision_embeddings=combined_memory_positional_embeddings, # keyword must match the (misspelled) parameter name in the memory attention API
+ num_object_pointer_tokens=num_object_pointer_tokens,
+ num_spatial_memory_tokens=num_spatial_memory_tokens,
+ )
+
+ # Reshape from (Batch, H*W, Channels) to (Batch, Channels, Height, Width)
+ conditioned_feature_map = (
+ conditioned_feature_map_flat.squeeze(1).permute(0, 2, 1).view(batch_size, num_channels, height, width)
+ )
+ return conditioned_feature_map
+
+ def _use_multimask(self, is_init_cond_frame: bool, point_inputs: Optional[dict]) -> bool:
+ """Whether to use multimask output in the SAM head."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(2)
+ multimask_output = (
+ self.config.multimask_output_in_sam
+ and (is_init_cond_frame or self.config.multimask_output_for_tracking)
+ and (self.config.multimask_min_pt_num <= num_pts <= self.config.multimask_max_pt_num)
+ )
+ return multimask_output
+
+ def _run_single_frame_inference(
+ self,
+ inference_session: EdgeTamVideoInferenceSession,
+ frame_idx: int,
+ obj_idx: int,
+ batch_size: int,
+ is_init_cond_frame: bool,
+ point_inputs: Optional[dict],
+ mask_inputs: Optional[torch.Tensor],
+ reverse: bool,
+ run_mem_encoder: bool,
+ prev_sam_mask_logits: Optional[torch.Tensor] = None,
+ streaming: bool = False,
+ ) -> dict[str, Any]:
+ """
+ Perform a single tracking step for video object segmentation.
+
+ Args:
+ inference_session (`EdgeTamVideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`):
+ Index of the current frame.
+ obj_idx (`int`):
+ Index of the current object.
+ batch_size (`int`):
+ Batch size of the current frame.
+ is_init_cond_frame (`bool`):
+ Whether this is an initial conditioning frame with user inputs.
+ point_inputs (`dict`, *optional*):
+ Point prompt inputs for the current frame.
+ mask_inputs (`torch.Tensor`, *optional*):
+ Mask prompt inputs for the current frame.
+ reverse (`bool`):
+ Whether to track in reverse time order.
+ run_mem_encoder (`bool`):
+ Whether to run the memory encoder on predicted masks.
+ prev_sam_mask_logits (`torch.Tensor`, *optional*):
+ Previously predicted SAM mask logits that can be fed with new clicks.
+ streaming (`bool`, *optional*, defaults to `False`):
+ Whether this is streaming inference.
+
+ Returns:
+ `dict`: Dictionary containing the tracking results for the current frame, including:
+ - pred_masks: Predicted low-resolution masks.
+ - object_pointer: Object pointer for memory.
+ - object_score_logits: Object score logits (inference only).
+ - maskmem_features: Memory features for future frames.
+ - maskmem_pos_enc: Memory positional encodings.
+ """
+ # Retrieve the image features for the current frame
+ current_vision_feats, current_vision_pos_embeds = self._prepare_vision_features(
+ inference_session, frame_idx, batch_size
+ )
+ # point and mask should not appear as input simultaneously on the same frame
+ if point_inputs is not None and mask_inputs is not None:
+ raise ValueError(
+ "point_inputs and mask_inputs should not appear as input simultaneously on the same frame"
+ )
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], self.backbone_feature_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None:
+ # We directly output the mask input (treating it as a ground-truth mask) without using the SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *self.backbone_feature_sizes[-1])
+ sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+ else:
+ # fuse the visual features with previous memory features from the memory bank
+ pix_feat = self._prepare_memory_conditioned_features(
+ inference_session=inference_session,
+ frame_idx=frame_idx,
+ obj_idx=obj_idx,
+ is_initial_conditioning_frame=is_init_cond_frame,
+ current_vision_features=current_vision_feats[-1],
+ current_vision_positional_embeddings=current_vision_pos_embeds[-1],
+ num_total_frames=inference_session.num_frames,
+ track_in_reverse_time=reverse,
+ streaming=streaming,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in a demo where such logits come from an earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here, as they are routed to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._single_frame_forward(
+ pixel_values=None, # Vision features already computed
+ input_points=point_inputs["point_coords"] if point_inputs is not None else None,
+ input_labels=point_inputs["point_labels"] if point_inputs is not None else None,
+ input_masks=mask_inputs,
+ image_embeddings=high_res_features + [pix_feat],
+ multimask_output=multimask_output,
+ )
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (which will be used to condition vision features in future frames)
+ maskmem_features = None
+ maskmem_pos_enc = None
+ if run_mem_encoder and self.num_maskmem > 0:
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats[-1],
+ pred_masks_high_res=sam_outputs.high_res_masks,
+ object_score_logits=sam_outputs.object_score_logits,
+ is_mask_from_pts=(point_inputs is not None or mask_inputs is not None),
+ )
+
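+ # Collect the per-frame outputs that will be cached in the inference session.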
+ current_out = {
+ "pred_masks": sam_outputs.pred_masks,
+ "object_pointer": sam_outputs.object_pointer,
+ "maskmem_features": maskmem_features if maskmem_features is not None else None,
+ "maskmem_pos_enc": maskmem_pos_enc,
+ }
+ if not self.training:
+ current_out["object_score_logits"] = sam_outputs.object_score_logits
+
+ return current_out
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats: torch.Tensor,
+ pred_masks_high_res: torch.Tensor,
+ object_score_logits: torch.Tensor,
+ is_mask_from_pts: bool,
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ """Encode the current image and its prediction into a memory feature."""
+ batch_size = current_vision_feats.size(1) # batch size on this frame
+ channels = self.hidden_dim
+ height, width = self.backbone_feature_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats.permute(1, 2, 0).view(batch_size, channels, height, width)
+ if is_mask_from_pts and not self.training:
+ # binarize the mask logits
+ mask_for_mem = (pred_masks_high_res > 0).to(pred_masks_high_res.dtype)
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ mask_for_mem = mask_for_mem * self.config.sigmoid_scale_for_mem_enc
+ mask_for_mem = mask_for_mem + self.config.sigmoid_bias_for_mem_enc
+
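+ # Fuse the image features with the (scaled) mask to produce spatial memory features and their positional encoding.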
+ maskmem_features, maskmem_pos_enc = self.memory_encoder(
+ pix_feat,
+ mask_for_mem,
+ )
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
+ if self.occlusion_spatial_embedding_parameter is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (1 - is_obj_appearing[..., None]) * self.occlusion_spatial_embedding_parameter[
+ ..., None, None
+ ].expand(*maskmem_features.shape)
+
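+ # Compress the spatial memory with the perceiver resampler into a smaller set of memory tokens.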
+ maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype)
+ maskmem_features, maskmem_pos_enc = self.spatial_perceiver(maskmem_features, maskmem_pos_enc)
+ maskmem_features = maskmem_features.to(pred_masks_high_res.dtype)
+ maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype)
+
+ return maskmem_features, maskmem_pos_enc
+
+ @torch.inference_mode()
+ @auto_docstring(
+ custom_intro="""
+ Propagate the objects through the video frames. Used when initializing an inference session with a whole video.
+ Yields EdgeTamVideoSegmentationOutput for each frame.
+ """
+ )
+ def propagate_in_video_iterator(
+ self,
+ inference_session: EdgeTamVideoInferenceSession,
+ start_frame_idx: Optional[int] = None,
+ max_frame_num_to_track: Optional[int] = None,
+ reverse: bool = False,
+ ) -> Iterator[EdgeTamVideoSegmentationOutput]:
+ r"""
+ inference_session (`EdgeTamVideoInferenceSession`):
+ The video inference session object.
+ start_frame_idx (`int`, *optional*):
+ The starting frame index for propagation.
+ Needs to be provided if `forward` hasn't been called on new inputs yet.
+ If not provided, the starting frame index defaults to the earliest frame with input points.
+ max_frame_num_to_track (`int`, *optional*):
+ The maximum number of frames to track.
+ reverse (`bool`, *optional*, defaults to `False`):
+ Whether to propagate in reverse.
+ """
+ num_frames = inference_session.num_frames
+
+ # set start index, end index, and processing order
+ if start_frame_idx is None:
+ # default: start from the earliest frame with input points
+ frames_with_inputs = [
+ frame_idx
+ for obj_output_dict in inference_session.output_dict_per_obj.values()
+ for frame_idx in obj_output_dict["cond_frame_outputs"]
+ ]
+ if not frames_with_inputs:
+ raise ValueError(
+ "Cannot determine the starting frame index; please specify it manually, or run inference on a frame with inputs first."
+ )
+ start_frame_idx = min(frames_with_inputs)
+ if max_frame_num_to_track is None:
+ # default: track all the frames in the video
+ max_frame_num_to_track = num_frames
+ if reverse:
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+ if start_frame_idx > 0:
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+ else:
+ processing_order = [] # skip reverse tracking if starting from frame 0
+ else:
+ end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1)
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
+
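+ # Run tracking frame by frame in the chosen order, yielding the segmentation output for each frame.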
+ for frame_idx in tqdm(processing_order, desc="propagate in video"):
+ edgetam_video_output = self(inference_session, frame_idx=frame_idx, reverse=reverse)
+ yield edgetam_video_output
+
+
+__all__ = ["EdgeTamVideoModel", "EdgeTamVideoInferenceSession", "EdgeTamVideoPreTrainedModel"]
diff --git a/src/transformers/models/edgetam_video/modular_edgetam_video.py b/src/transformers/models/edgetam_video/modular_edgetam_video.py
new file mode 100644
index 000000000000..b520cd5a756b
--- /dev/null
+++ b/src/transformers/models/edgetam_video/modular_edgetam_video.py
@@ -0,0 +1,1243 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch import Tensor
+
+from transformers.models.sam2.modeling_sam2 import (
+ eager_attention_forward,
+ window_partition,
+)
+from transformers.utils.generic import OutputRecorder
+
+from ...activations import ACT2FN
+from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
+from ...pytorch_utils import compile_compatible_method_lru_cache
+from ...utils import (
+ auto_docstring,
+)
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..sam2_video.configuration_sam2_video import (
+ Sam2VideoConfig,
+ Sam2VideoMaskDecoderConfig,
+ Sam2VideoPromptEncoderConfig,
+)
+from ..sam2_video.modeling_sam2_video import (
+ Sam2VideoAttention,
+ Sam2VideoFeedForward,
+ Sam2VideoInferenceSession,
+ Sam2VideoLayerNorm,
+ Sam2VideoMemoryAttention,
+ Sam2VideoMemoryEncoder,
+ Sam2VideoMemoryFuserCXBlock,
+ Sam2VideoModel,
+ Sam2VideoPositionEmbeddingSine,
+ Sam2VideoPreTrainedModel,
+ Sam2VideoTwoWayAttentionBlock,
+ Sam2VideoVisionEncoderOutput,
+ Sam2VideoVisionRotaryEmbedding,
+ rotate_pairwise,
+)
+
+
+class EdgeTamVideoPromptEncoderConfig(Sam2VideoPromptEncoderConfig):
+ pass
+
+
+class EdgeTamVideoMaskDecoderConfig(Sam2VideoMaskDecoderConfig):
+ pass
+
+
+class EdgeTamVideoConfig(Sam2VideoConfig):
+ r"""
+ [`EdgeTamVideoConfig`] is the configuration class to store the configuration of a [`EdgeTamVideoModel`]. It is used to instantiate an
+ EdgeTAM video model according to the specified arguments, defining the memory attention, memory encoder, and image encoder
+ configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
+ [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (Union[`dict`, `EdgeTamVisionConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`EdgeTamVisionConfig`].
+ prompt_encoder_config (Union[`dict`, `EdgeTamVideoPromptEncoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`EdgeTamVideoPromptEncoderConfig`].
+ mask_decoder_config (Union[`dict`, `EdgeTamVideoMaskDecoderConfig`], *optional*):
+ Dictionary of configuration options used to initialize [`EdgeTamVideoMaskDecoderConfig`].
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ Standard deviation for parameter initialization.
+ num_maskmem (`int`, *optional*, defaults to 7):
+ The number of memory slots for the mask memory.
+ image_size (`int`, *optional*, defaults to 1024):
+ The size of the input images.
+ sigmoid_scale_for_mem_enc (`float`, *optional*, defaults to 20.0):
+ Scale factor for the sigmoid function in the memory encoder.
+ sigmoid_bias_for_mem_enc (`float`, *optional*, defaults to -10.0):
+ Bias for the sigmoid function in the memory encoder.
+ enable_occlusion_spatial_embedding (`bool`, *optional*, defaults to `True`):
+ Whether to enable spatial embedding for occlusions.
+ multimask_output_in_sam (`bool`, *optional*, defaults to `True`):
+ Whether to output multiple masks from the SAM head.
+ multimask_min_pt_num (`int`, *optional*, defaults to 0):
+ The minimum number of points to trigger multimask output.
+ multimask_max_pt_num (`int`, *optional*, defaults to 1):
+ The maximum number of points to trigger multimask output.
+ multimask_output_for_tracking (`bool`, *optional*, defaults to `True`):
+ Whether to use multimask output for tracking.
+ max_object_pointers_in_encoder (`int`, *optional*, defaults to 16):
+ The maximum number of object pointers in the encoder.
+ enable_temporal_pos_encoding_for_object_pointers (`bool`, *optional*, defaults to `True`):
+ Whether to enable temporal positional encoding for object pointers.
+ memory_attention_hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the memory attention hidden states.
+ memory_attention_num_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the memory attention module.
+ memory_attention_num_attention_heads (`int`, *optional*, defaults to 1):
+ Number of attention heads for each attention layer in the memory attention.
+ memory_attention_downsample_rate (`int`, *optional*, defaults to 1):
+ The downsample rate for the attention layers.
+ memory_attention_mlp_hidden_size (`int`, *optional*, defaults to 2048):
+ The dimension of the feedforward network in the memory attention module.
+ memory_attention_mlp_hidden_act (`str`, *optional*, defaults to `"relu"`):
+ The non-linear activation function in the feedforward network in the memory attention module.
+ memory_attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout rate for the memory attention module.
+ memory_attention_rope_theta (`float`, *optional*, defaults to 10000):
+ The RoPE theta parameter.
+ memory_attention_rope_feat_sizes (`Tuple[int, int]`, *optional*, defaults to `[64, 64]`):
+ The feature sizes for the RoPE positional encoding.
+ memory_attention_rope_k_sizes (`List[int]`, *optional*, defaults to `[16, 16]`):
+ The key feature sizes for the RoPE positional encoding in memory attention.
+ memory_attention_rope_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout rate for the RoPE positional encoding.
+ perceiver_resampler_num_latents (`int`, *optional*, defaults to 256):
+ The number of 1D latent tokens in the perceiver resampler.
+ perceiver_resampler_num_latents_2d (`int`, *optional*, defaults to 256):
+ The number of 2D latent tokens in the perceiver resampler.
+ perceiver_resampler_hidden_size (`int`, *optional*, defaults to 64):
+ The hidden size of the perceiver resampler.
+ perceiver_resampler_mlp_intermediate_size (`int`, *optional*, defaults to 256):
+ The intermediate size of the feedforward network in the perceiver resampler.
+ perceiver_resampler_num_attention_heads (`int`, *optional*, defaults to 1):
+ The number of attention heads in the perceiver resampler.
+ perceiver_resampler_attention_head_dim (`int`, *optional*, defaults to 64):
+ The dimension of each attention head in the perceiver resampler.
+ perceiver_resampler_num_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the perceiver resampler.
+ perceiver_resampler_hidden_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout rate for the hidden layers in the perceiver resampler.
+ perceiver_resampler_attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout rate for the attention layers in the perceiver resampler.
+ memory_encoder_hidden_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the memory encoder hidden states.
+ memory_encoder_output_channels (`int`, *optional*, defaults to 64):
+ The number of output channels for the memory encoder.
+ mask_downsampler_embed_dim (`int`, *optional*, defaults to 256):
+ The dimension of the mask downsampler embedding.
+ memory_fuser_intermediate_dim (`int`, *optional*, defaults to 1024):
+ The intermediate dimension of the memory fuser feedforward network.
+ mask_downsampler_kernel_size (`int`, *optional*, defaults to 3):
+ The kernel size for the mask downsampler.
+ mask_downsampler_stride (`int`, *optional*, defaults to 2):
+ The stride for the mask downsampler.
+ mask_downsampler_padding (`int`, *optional*, defaults to 1):
+ The padding for the mask downsampler.
+ mask_downsampler_total_stride (`int`, *optional*, defaults to 16):
+ The total stride for the mask downsampler.
+ mask_downsampler_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the mask downsampler.
+ memory_fuser_num_layers (`int`, *optional*, defaults to 2):
+ The number of layers in the memory fuser.
+ memory_fuser_embed_dim (`int`, *optional*, defaults to 256):
+ The dimension of the memory fuser embedding.
+ memory_fuser_kernel_size (`int`, *optional*, defaults to 7):
+ The kernel size for the memory fuser.
+ memory_fuser_padding (`int`, *optional*, defaults to 3):
+ The padding for the memory fuser.
+ memory_fuser_layer_scale_init_value (`float`, *optional*, defaults to 1e-06):
+ The initial value for the layer scale in the memory fuser.
+ memory_fuser_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function in the memory fuser.
+
+ Example:
+
+ ```python
+ >>> from transformers import (
+ ... EdgeTamVisionConfig,
+ ... EdgeTamVideoPromptEncoderConfig,
+ ... EdgeTamVideoMaskDecoderConfig,
+ ... EdgeTamVideoModel,
+ ... EdgeTamVideoConfig,
+ ... )
+
+ >>> # Initializing an EdgeTamVideoConfig with the `"facebook/EdgeTAM"` style configuration
+ >>> configuration = EdgeTamVideoConfig()
+
+ >>> # Initializing an EdgeTamVideoModel (with random weights) from the `"facebook/EdgeTAM"` style configuration
+ >>> model = EdgeTamVideoModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+
+ >>> # We can also initialize an EdgeTamVideoConfig from an EdgeTamVisionConfig, EdgeTamVideoPromptEncoderConfig, and EdgeTamVideoMaskDecoderConfig
+
+ >>> # Initializing EdgeTAM vision encoder, prompt encoder, and mask decoder configurations
+ >>> vision_config = EdgeTamVisionConfig()
+ >>> prompt_encoder_config = EdgeTamVideoPromptEncoderConfig()
+ >>> mask_decoder_config = EdgeTamVideoMaskDecoderConfig()
+
+ >>> config = EdgeTamVideoConfig(vision_config, prompt_encoder_config, mask_decoder_config)
+ ```"""
+
+ model_type = "edgetam_video"
+ sub_configs = {
+ "vision_config": AutoConfig,
+ "prompt_encoder_config": EdgeTamVideoPromptEncoderConfig,
+ "mask_decoder_config": EdgeTamVideoMaskDecoderConfig,
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ prompt_encoder_config=None,
+ mask_decoder_config=None,
+ initializer_range=0.02,
+ num_maskmem=7,
+ image_size=1024,
+ sigmoid_scale_for_mem_enc=20.0,
+ sigmoid_bias_for_mem_enc=-10.0,
+ enable_occlusion_spatial_embedding=True,
+ multimask_output_in_sam=True,
+ multimask_min_pt_num=0,
+ multimask_max_pt_num=1,
+ multimask_output_for_tracking=True,
+ max_object_pointers_in_encoder=16,
+ enable_temporal_pos_encoding_for_object_pointers=True,
+ # memory attention
+ memory_attention_hidden_size=256,
+ memory_attention_num_layers=2,
+ memory_attention_num_attention_heads=1,
+ memory_attention_downsample_rate=1,
+ memory_attention_mlp_hidden_size=2048,
+ memory_attention_mlp_hidden_act="relu",
+ memory_attention_dropout=0.1,
+ memory_attention_rope_theta=10000,
+ memory_attention_rope_feat_sizes=None,
+ memory_attention_rope_k_sizes=None,
+ memory_attention_rope_dropout=0.1,
+ # spatial perceiver resampler
+ perceiver_resampler_num_latents=256,
+ perceiver_resampler_num_latents_2d=256,
+ perceiver_resampler_hidden_size=64,
+ perceiver_resampler_mlp_intermediate_size=256,
+ perceiver_resampler_num_attention_heads=1,
+ perceiver_resampler_attention_head_dim=64,
+ perceiver_resampler_num_layers=2,
+ perceiver_resampler_hidden_dropout=0.0,
+ perceiver_resampler_attention_dropout=0.0,
+ # memory encoder
+ memory_encoder_hidden_size=256,
+ memory_encoder_output_channels=64,
+ mask_downsampler_embed_dim=256,
+ memory_fuser_intermediate_dim=1024,
+ mask_downsampler_kernel_size=3,
+ mask_downsampler_stride=2,
+ mask_downsampler_padding=1,
+ mask_downsampler_total_stride=16,
+ mask_downsampler_hidden_act="gelu",
+ memory_fuser_num_layers=2,
+ memory_fuser_embed_dim=256,
+ memory_fuser_kernel_size=7,
+ memory_fuser_padding=3,
+ memory_fuser_layer_scale_init_value=1e-6,
+ memory_fuser_hidden_act="gelu",
+ **kwargs,
+ ):
+ PretrainedConfig.__init__(self, **kwargs)
+ vision_config = vision_config if vision_config is not None else {}
+ prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {}
+ mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {}
+ memory_attention_rope_feat_sizes = (
+ [64, 64] if memory_attention_rope_feat_sizes is None else memory_attention_rope_feat_sizes
+ )
+ memory_attention_rope_k_sizes = (
+ [16, 16] if memory_attention_rope_k_sizes is None else memory_attention_rope_k_sizes
+ )
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "sam2_vision_model")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ if isinstance(prompt_encoder_config, EdgeTamVideoPromptEncoderConfig):
+ prompt_encoder_config = prompt_encoder_config.to_dict()
+ if isinstance(mask_decoder_config, EdgeTamVideoMaskDecoderConfig):
+ mask_decoder_config = mask_decoder_config.to_dict()
+
+ self.vision_config = vision_config
+ self.prompt_encoder_config = EdgeTamVideoPromptEncoderConfig(**prompt_encoder_config)
+ self.mask_decoder_config = EdgeTamVideoMaskDecoderConfig(**mask_decoder_config)
+
+ self.initializer_range = initializer_range
+ self.num_maskmem = num_maskmem # default 1 input frame + 6 previous frames
+ self.image_size = image_size
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc # scale factor for mask sigmoid prob
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc # bias factor for mask sigmoid prob
+ self.enable_occlusion_spatial_embedding = enable_occlusion_spatial_embedding
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.max_object_pointers_in_encoder = max_object_pointers_in_encoder
+ self.enable_temporal_pos_encoding_for_object_pointers = enable_temporal_pos_encoding_for_object_pointers
+
+ # memory attention
+ self.memory_attention_hidden_size = memory_attention_hidden_size
+ self.memory_attention_num_layers = memory_attention_num_layers
+ self.memory_attention_num_attention_heads = memory_attention_num_attention_heads
+ self.memory_attention_downsample_rate = memory_attention_downsample_rate
+ self.memory_attention_mlp_hidden_size = memory_attention_mlp_hidden_size
+ self.memory_attention_mlp_hidden_act = memory_attention_mlp_hidden_act
+ self.memory_attention_dropout = memory_attention_dropout
+ self.memory_attention_rope_theta = memory_attention_rope_theta
+ self.memory_attention_rope_feat_sizes = memory_attention_rope_feat_sizes
+ self.memory_attention_rope_k_sizes = memory_attention_rope_k_sizes
+ self.memory_attention_rope_dropout = memory_attention_rope_dropout
+
+ # spatial perceiver resampler
+ self.perceiver_resampler_num_latents = perceiver_resampler_num_latents
+ self.perceiver_resampler_num_latents_2d = perceiver_resampler_num_latents_2d
+ self.perceiver_resampler_hidden_size = perceiver_resampler_hidden_size
+ self.perceiver_resampler_mlp_intermediate_size = perceiver_resampler_mlp_intermediate_size
+ self.perceiver_resampler_attention_head_dim = perceiver_resampler_attention_head_dim
+ self.perceiver_resampler_num_attention_heads = perceiver_resampler_num_attention_heads
+ self.perceiver_resampler_num_layers = perceiver_resampler_num_layers
+ self.perceiver_resampler_hidden_dropout = perceiver_resampler_hidden_dropout
+ self.perceiver_resampler_attention_dropout = perceiver_resampler_attention_dropout
+
+ # memory encoder
+ self.memory_encoder_hidden_size = memory_encoder_hidden_size
+ self.memory_encoder_output_channels = memory_encoder_output_channels
+ self.mask_downsampler_embed_dim = mask_downsampler_embed_dim
+ self.mask_downsampler_kernel_size = mask_downsampler_kernel_size
+ self.mask_downsampler_stride = mask_downsampler_stride
+ self.mask_downsampler_padding = mask_downsampler_padding
+ self.mask_downsampler_total_stride = mask_downsampler_total_stride
+ self.mask_downsampler_hidden_act = mask_downsampler_hidden_act
+ self.memory_fuser_num_layers = memory_fuser_num_layers
+ self.memory_fuser_embed_dim = memory_fuser_embed_dim
+ self.memory_fuser_intermediate_dim = memory_fuser_intermediate_dim
+ self.memory_fuser_kernel_size = memory_fuser_kernel_size
+ self.memory_fuser_padding = memory_fuser_padding
+ self.memory_fuser_layer_scale_init_value = memory_fuser_layer_scale_init_value
+ self.memory_fuser_hidden_act = memory_fuser_hidden_act
+
+
+class EdgeTamVideoLayerNorm(Sam2VideoLayerNorm):
+ pass
+
+
+class EdgeTamVideoMemoryFuserCXBlock(Sam2VideoMemoryFuserCXBlock):
+ pass
+
+
+class EdgeTamVideoVisionEncoderOutput(Sam2VideoVisionEncoderOutput):
+ pass
+
+
+class EdgeTamVideoVisionRotaryEmbedding(Sam2VideoVisionRotaryEmbedding):
+ def __init__(self, config: EdgeTamVideoConfig, end_x: Optional[int] = None, end_y: Optional[int] = None):
+ nn.Module.__init__(self)
+ dim = config.memory_attention_hidden_size // (
+ config.memory_attention_downsample_rate * config.memory_attention_num_attention_heads
+ )
+ # Ensure the head dimension is divisible by 4 for axial RoPE splitting
+ if dim % 4 != 0:
+ raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+ end_x, end_y = config.memory_attention_rope_feat_sizes if end_x is None else (end_x, end_y)
+ freqs = 1.0 / (config.memory_attention_rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ # Generate 2D position indices for axial rotary embedding
+ flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
+ x_positions = flattened_indices % end_x
+ y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor")
+ freqs_x = torch.outer(x_positions, freqs).float()
+ freqs_y = torch.outer(y_positions, freqs).float()
+ inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
+ inv_freq = inv_freq.repeat_interleave(2, dim=-1)
+ # directly register the cos and sin embeddings as we have a fixed feature shape
+ self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
+ self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
+
+
+class EdgeTamVideoAttention(Sam2VideoAttention):
+ pass
+
+
+def apply_rotary_pos_emb_2d_self_attn(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary position embedding to query and key tensors for self-attention.
+
+ Args:
+ q: Query tensor of shape (..., seq_len, head_dim)
+ k: Key tensor of shape (..., seq_len, head_dim)
+ cos: Cosine position embedding of shape (seq_len, head_dim)
+ sin: Sine position embedding of shape (seq_len, head_dim)
+
+ Returns:
+ Rotated (q, k) tensors
+ """
+ # Apply RoPE to queries
+ q_embed = q.float() # force upscale to float32 as in the original implementation
+ q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)
+
+ # Apply RoPE to keys (same embeddings as queries for self-attention)
+ k_embed = k.float() # force upscale to float32 as in the original implementation
+ k_embed = (k_embed * cos) + (rotate_pairwise(k_embed) * sin)
+
+ return q_embed.type_as(q), k_embed.type_as(k)
+
+
+def apply_rotary_pos_emb_2d_cross_attn(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ cos_k: torch.Tensor,
+ sin_k: torch.Tensor,
+ num_k_exclude_rope: int = 0,
+ repeat_freqs_k: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary position embedding to query and key tensors for cross-attention.
+
+ Args:
+ q: Query tensor of shape (..., seq_len, head_dim)
+ k: Key tensor of shape (..., seq_len, head_dim)
+ cos: Cosine position embedding of shape (seq_len, head_dim)
+ sin: Sine position embedding of shape (seq_len, head_dim)
+ cos_k: Cosine position embedding for keys of shape (seq_len, head_dim)
+ sin_k: Sine position embedding for keys of shape (seq_len, head_dim)
+ num_k_exclude_rope: Number of tokens at end of k to exclude from RoPE (e.g., object pointer tokens)
+ repeat_freqs_k: Frequency repetition for keys in cross-attention (e.g., for spatial memory tokens)
+
+ Returns:
+ Rotated (q, k) tensors
+ """
+ # Apply RoPE to queries (always straightforward)
+ q_embed = q.float()
+ q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)
+
+ # Split keys: RoPE tokens and excluded tokens (e.g., object pointers)
+ num_total_k_tokens = k.shape[-2]
+ k_for_rope = k[..., : num_total_k_tokens - num_k_exclude_rope, :]
+ k_excluded = k[..., num_total_k_tokens - num_k_exclude_rope :, :]
+
+ # Early return if no keys need RoPE
+ if k_for_rope.shape[-2] == 0:
+ return q_embed.type_as(q), k_excluded
+
+ batch_size, num_heads, k_seq_len, channels_per_head = k_for_rope.shape
+
+ # Handle temporal/spatial token structure for memory
+ # Keys have temporal + spatial structure, only spatial tokens get RoPE
+ tokens_per_group = k_seq_len // repeat_freqs_k
+ spatial_tokens = cos_k.shape[-2]
+ temporal_tokens = tokens_per_group - spatial_tokens
+
+ # Reshape and separate temporal/spatial tokens
+ k_grouped = k_for_rope.view(batch_size, num_heads, repeat_freqs_k, tokens_per_group, channels_per_head)
+ k_temporal = k_grouped[..., :temporal_tokens, :].reshape(batch_size, num_heads, -1, channels_per_head)
+ k_spatial = k_grouped[..., temporal_tokens:, :].reshape(batch_size, num_heads, -1, channels_per_head)
+
+ # Only apply RoPE to spatial tokens
+ k_rope_input = k_spatial
+
+ # Prepare position embeddings for repeated groups
+ if repeat_freqs_k > 1:
+ cos_k = cos_k.repeat(1, 1, repeat_freqs_k, 1)
+ sin_k = sin_k.repeat(1, 1, repeat_freqs_k, 1)
+
+ # Apply RoPE to spatial tokens
+ k_spatial_embed = k_rope_input.float()
+ k_spatial_embed = (k_spatial_embed * cos_k) + (rotate_pairwise(k_spatial_embed) * sin_k)
+
+ # Reconstruct: temporal + spatial tokens back to original structure
+ k_spatial_reshaped = k_spatial_embed.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head)
+ k_temporal_reshaped = k_temporal.view(batch_size, num_heads, repeat_freqs_k, -1, channels_per_head)
+ k_final = torch.cat([k_temporal_reshaped, k_spatial_reshaped], dim=3)
+ k_final = k_final.view(batch_size, num_heads, k_seq_len, channels_per_head)
+
+ # Combine RoPE-processed keys with excluded tokens
+ k_embed = torch.cat([k_final.type_as(k), k_excluded], dim=-2)
+ return q_embed.type_as(q), k_embed
+
+
+class EdgeTamVideoRoPESelfAttention(nn.Module):
+ """Self-attention with rotary position encoding."""
+
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.memory_attention_hidden_size
+ self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate
+ self.num_attention_heads = config.memory_attention_num_attention_heads
+ self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+ self.dropout_p = config.memory_attention_rope_dropout
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tensor:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ # Apply rotary position encoding for self-attention
+ query, key = apply_rotary_pos_emb_2d_self_attn(query, key, cos=cos, sin=sin)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.dropout_p,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class EdgeTamVideoRoPECrossAttention(nn.Module):
+ """Cross-attention with rotary position encoding."""
+
+ def __init__(self, config: EdgeTamVideoConfig, kv_in_dim: int):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.memory_attention_hidden_size
+ self.internal_dim = self.hidden_size // config.memory_attention_downsample_rate
+ self.num_attention_heads = config.memory_attention_num_attention_heads
+ self.head_dim = self.internal_dim // config.memory_attention_num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.kv_in_dim = kv_in_dim
+
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
+ self.dropout_p = config.memory_attention_rope_dropout
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ position_embeddings_k: tuple[torch.Tensor, torch.Tensor],
+ num_k_exclude_rope: int = 0,
+ rope_k_repeat: int = 0,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tensor:
+ # Input projections
+ batch_size, point_batch_size = query.shape[:2]
+ new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
+
+ query = self.q_proj(query).view(*new_shape).transpose(1, 2)
+ key = self.k_proj(key).view(*new_shape).transpose(1, 2)
+ value = self.v_proj(value).view(*new_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ cos_k, sin_k = position_embeddings_k
+ # Apply rotary position encoding for cross-attention
+ query, key = apply_rotary_pos_emb_2d_cross_attn(
+ query,
+ key,
+ cos=cos,
+ sin=sin,
+ cos_k=cos_k,
+ sin_k=sin_k,
+ repeat_freqs_k=rope_k_repeat,
+ num_k_exclude_rope=num_k_exclude_rope,
+ )
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.dropout_p,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(
+ batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
+ ).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class EdgeTamVideoTwoWayAttentionBlock(Sam2VideoTwoWayAttentionBlock):
+ pass
+
+
+class EdgeTamVideoPositionEmbeddingSine(Sam2VideoPositionEmbeddingSine):
+ # maxsize=2 because we need to cache the forward method for both memory encoder and perceiver resampler
+ @compile_compatible_method_lru_cache(maxsize=2)
+ def forward(self, **super_kwargs):
+ return super().forward(**super_kwargs)
+
+
+class EdgeTamVideoMemoryEncoder(Sam2VideoMemoryEncoder):
+ pass
+
+
+class EdgeTamVideoFeedForward(Sam2VideoFeedForward):
+ pass
+
+
+class EdgeTamVideoPreTrainedModel(Sam2VideoPreTrainedModel):
+ pass
+
+
+class EdgeTamVideoInferenceSession(Sam2VideoInferenceSession):
+ pass
+
+
+class EdgeTamVideoMemoryAttentionMLP(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.memory_attention_hidden_size
+ self.intermediate_size = config.memory_attention_mlp_hidden_size
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
+ self.dropout = nn.Dropout(config.memory_attention_dropout)
+ self.act_fn = ACT2FN[config.memory_attention_mlp_hidden_act]
+
+ def forward(self, x):
+ return self.down_proj(self.dropout(self.act_fn(self.up_proj(x))))
+
+
+class EdgeTamVideoMemoryAttentionLayer(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ hidden_size = config.memory_attention_hidden_size
+ self.self_attn = EdgeTamVideoRoPESelfAttention(config)
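+ # Cross-attention keys/values come from the memory bank, which matches the memory channel dimension (64 by default).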
+ self.cross_attn_image = EdgeTamVideoRoPECrossAttention(config, kv_in_dim=64)
+
+ # MLP module
+ self.mlp = EdgeTamVideoMemoryAttentionMLP(config)
+
+ self.layer_norm1 = nn.LayerNorm(hidden_size)
+ self.layer_norm2 = nn.LayerNorm(hidden_size)
+ self.layer_norm3 = nn.LayerNorm(hidden_size)
+ self.dropout1 = nn.Dropout(config.memory_attention_dropout)
+ self.dropout2 = nn.Dropout(config.memory_attention_dropout)
+ self.dropout3 = nn.Dropout(config.memory_attention_dropout)
+
+ def forward(
+ self,
+ queries: Tensor,
+ keys: Tensor,
+ key_point_embedding: Tensor,
+ rope_position_embeddings: tuple[Tensor, Tensor],
+ rope_position_embeddings_k: Optional[tuple[Tensor, Tensor]] = None,
+ num_k_exclude_rope: int = 0,
+ rope_k_repeat: int = 0,
+ ) -> torch.Tensor:
+ # Self-Attention
+ query = self.layer_norm1(queries)
+ query, _ = self.self_attn(query=query, key=query, value=query, position_embeddings=rope_position_embeddings)
+ queries = queries + self.dropout1(query)
+
+ # Cross-Attention
+ query = self.layer_norm2(queries)
+ query, _ = self.cross_attn_image(
+ query=query,
+ key=keys + key_point_embedding,
+ value=keys,
+ position_embeddings=rope_position_embeddings,
+ position_embeddings_k=rope_position_embeddings_k,
+ num_k_exclude_rope=num_k_exclude_rope,
+ rope_k_repeat=rope_k_repeat,
+ )
+ queries = queries + self.dropout2(query)
+ # MLP
+ query = self.layer_norm3(queries)
+ query = self.mlp(query)
+ queries = queries + self.dropout3(query)
+ return queries
+
+
+class EdgeTamVideoMemoryAttention(Sam2VideoMemoryAttention):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__(config)
+ self.rotary_emb_k = EdgeTamVideoVisionRotaryEmbedding(
+ config, end_x=config.memory_attention_rope_k_sizes[0], end_y=config.memory_attention_rope_k_sizes[1]
+ )
+
+ def forward(
+ self,
+ current_vision_features: torch.Tensor,
+ memory: torch.Tensor,
+ current_vision_position_embeddings: Optional[Tensor] = None,
+ memory_posision_embeddings: Optional[Tensor] = None,
+ num_object_pointer_tokens: int = 0,
+ num_spatial_memory_tokens: int = -1,
+ ):
+ """
+ Args:
+ current_vision_features (`torch.FloatTensor`):
+ The current vision features used for self-attention.
+ memory (`torch.FloatTensor`):
+ The memory features used for cross-attention.
+ current_vision_position_embeddings (`torch.FloatTensor`, *optional*):
+ The position embeddings for the current vision features.
+ memory_posision_embeddings (`torch.FloatTensor`, *optional*):
+ The position embeddings for the memory features.
+ num_object_pointer_tokens (`int`, *optional*, defaults to 0):
+ The number of object pointer tokens.
+ """
+ output = current_vision_features
+ if current_vision_position_embeddings is not None:
+ output = output + 0.1 * current_vision_position_embeddings
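+ # the current-frame positional embeddings are folded in with a small fixed weight (0.1) rather than
+ # added at full scale, as in the SAM 2 memory attention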
+
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ memory = memory.transpose(0, 1).unsqueeze(1)
+ memory_posision_embeddings = memory_posision_embeddings.transpose(0, 1).unsqueeze(1)
+ rope_position_embeddings = self.rotary_emb()
+ rope_position_embeddings_k = self.rotary_emb_k()
+ for layer in self.layers:
+ output = layer(
+ queries=output.unsqueeze(1) if output.ndim == 3 else output,
+ keys=memory,
+ key_point_embedding=memory_posision_embeddings,
+ rope_position_embeddings=rope_position_embeddings,
+ rope_position_embeddings_k=rope_position_embeddings_k,
+ num_k_exclude_rope=num_object_pointer_tokens,
+ rope_k_repeat=num_spatial_memory_tokens,
+ )
+
+ normed_output = self.layer_norm(output)
+
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+
+ return normed_output
+
+
+class EdgeTamVideoPerceiverMLP(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.hidden_size = config.perceiver_resampler_hidden_size
+ self.intermediate_size = config.perceiver_resampler_mlp_intermediate_size
+
+ self.layer_norm = nn.LayerNorm(self.hidden_size)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = nn.GELU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.down_proj(self.act_fn(self.up_proj(hidden_states)))
+ return hidden_states
+
+
+class EdgeTamVideoPerceiverAttention(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.perceiver_resampler_hidden_size
+ self.num_attention_heads = config.perceiver_resampler_num_attention_heads
+ self.head_dim = config.perceiver_resampler_attention_head_dim
+ self.attention_dropout = config.perceiver_resampler_attention_dropout
+
+ self.inner_dim = self.head_dim * self.num_attention_heads
+ self.scaling = self.head_dim**-0.5
+ self.is_causal = False
+
+ self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
+ self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ positional_encoding: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ # Project queries, keys, and values
+ query = self.q_proj(query)
+ key = self.k_proj(key)
+ value = self.v_proj(value)
+
+ # Reshape for multi-head attention
+ batch_size, seq_len_q = query.shape[:2]
+ query = query.view(batch_size, seq_len_q, self.num_attention_heads, self.head_dim).transpose(1, 2)
+ seq_len_kv = key.shape[1]
+ key = key.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2)
+ value = value.view(batch_size, seq_len_kv, self.num_attention_heads, self.head_dim).transpose(1, 2)
+
+ # Add positional encoding if provided
+ if positional_encoding is not None:
+ pos_encoding = positional_encoding.view(
+ batch_size, seq_len_kv, self.num_attention_heads, self.head_dim
+ ).transpose(1, 2)
+ key = key + pos_encoding
+ value = value + pos_encoding
+
+ # Apply attention
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, _ = attention_interface(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ is_causal=self.is_causal,
+ **kwargs,
+ )
+
+ # Reshape output
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.inner_dim)
+ return self.o_proj(attn_output)
+
+
+class EdgeTamVideoPerceiverEncoderLayer(nn.Module):
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+
+ self.cross_attention = EdgeTamVideoPerceiverAttention(config)
+ self.mlp = EdgeTamVideoPerceiverMLP(config)
+ self.dropout = nn.Dropout(config.perceiver_resampler_hidden_dropout)
+
+ self.self_attention = EdgeTamVideoPerceiverAttention(config)
+ self.self_mlp = EdgeTamVideoPerceiverMLP(config)
+
+ # Layer norms moved from attention classes to here
+ self.layer_norm_input = nn.LayerNorm(config.perceiver_resampler_hidden_size)
+ self.layer_norm_latents = nn.LayerNorm(config.perceiver_resampler_hidden_size)
+ self.layer_norm_self = nn.LayerNorm(config.perceiver_resampler_hidden_size)
+
+ def forward(
+ self,
+ latents: torch.Tensor,
+ input_features: torch.Tensor,
+ positional_encoding: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Cross attention with layer norms
+ normalized_latents = self.layer_norm_latents(latents)
+ normalized_input = self.layer_norm_input(input_features)
+ cross_attention_output = self.cross_attention(
+ query=normalized_latents,
+ key=normalized_input,
+ value=normalized_input,
+ positional_encoding=positional_encoding,
+ )
+ latents = latents + self.dropout(cross_attention_output)
+
+ mlp_output = self.mlp(latents)
+ latents = latents + mlp_output
+
+ # Self attention with layer norm
+ normalized_latents_self = self.layer_norm_self(latents)
+ self_attention_output = self.self_attention(
+ query=normalized_latents_self, key=normalized_latents_self, value=normalized_latents_self
+ )
+ latents = latents + self_attention_output
+
+ self_mlp_output = self.self_mlp(latents)
+ latents = latents + self_mlp_output
+
+ return latents
+
+
+class EdgeTamVideoPerceiverResampler(nn.Module):
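+ """
+ Perceiver-style resampler that compresses a `(batch, channels, height, width)` memory feature map into
+ `num_latents_1d + num_latents_2d` latent tokens per frame: the 1D latents attend over the full flattened
+ feature map, while each 2D latent summarizes one non-overlapping spatial window.
+ """
+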
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.perceiver_resampler_hidden_size
+ self.num_latents_1d = config.perceiver_resampler_num_latents
+ self.num_latents_2d = config.perceiver_resampler_num_latents_2d
+ self.num_layers = config.perceiver_resampler_num_layers
+
+ if self.num_latents_1d > 0:
+ self.latents_1d = nn.Parameter(torch.randn(self.num_latents_1d, self.hidden_size))
+ if self.num_latents_2d > 0:
+ self.latents_2d = nn.Parameter(torch.randn(self.num_latents_2d, self.hidden_size))
+
+ self.positional_encoding = EdgeTamVideoPositionEmbeddingSine(
+ num_pos_feats=self.hidden_size // 2, normalize=True
+ )
+
+ self.layers = nn.ModuleList([EdgeTamVideoPerceiverEncoderLayer(config) for _ in range(self.num_layers)])
+
+ self.layer_norm = nn.LayerNorm(self.hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ positional_encoding: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ output_latents = []
+ output_positional_encodings = []
+
+ if self.num_latents_1d > 0:
+ latents_1d, pos_1d = self._forward_1d(hidden_states, positional_encoding)
+ output_latents.append(latents_1d)
+ output_positional_encodings.append(pos_1d)
+
+ if self.num_latents_2d > 0:
+ latents_2d, pos_2d = self._forward_2d(hidden_states)
+ output_latents.append(latents_2d)
+ output_positional_encodings.append(pos_2d)
+
+ combined_latents = torch.cat(output_latents, dim=1)
+
+ combined_positional_encoding = None
+ if positional_encoding is not None and output_positional_encodings:
+ combined_positional_encoding = torch.cat(output_positional_encodings, dim=1)
+
+ return combined_latents, combined_positional_encoding
+
+ def _forward_1d(
+ self,
+ hidden_states: torch.Tensor,
+ positional_encoding: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ batch_size = hidden_states.shape[0]
+
+ latents = self.latents_1d.unsqueeze(0).expand(batch_size, -1, -1)
+ flattened_features = hidden_states.permute(0, 2, 3, 1).flatten(1, 2)
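+ # (Batch, Channels, H, W) -> (Batch, H*W, Channels)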
+
+ positional_features = None
+ if positional_encoding is not None:
+ positional_features = positional_encoding.permute(0, 2, 3, 1).flatten(1, 2)
+
+ for layer in self.layers:
+ latents = layer(latents, flattened_features, positional_features)
+
+ latents = self.layer_norm(latents)
+
+ output_positional_encoding = None
+ if positional_encoding is not None:
+ output_positional_encoding = torch.zeros_like(latents)
+
+ return latents, output_positional_encoding
+
+ def _forward_2d(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ batch_size, channels, height, width = hidden_states.shape
+
+ latents_2d = self.latents_2d.unsqueeze(0).expand(batch_size, -1, -1).view(-1, 1, channels)
+
+ num_windows_per_dim = int(math.sqrt(self.num_latents_2d))
+ window_size = height // num_windows_per_dim
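+ # each 2D latent summarizes one non-overlapping window; this assumes `num_latents_2d` is a perfect
+ # square and the feature height is divisible by `num_windows_per_dim`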
+
+ windowed_input = hidden_states.permute(0, 2, 3, 1)
+ windowed_features, _ = window_partition(windowed_input, window_size)
+ windowed_features = windowed_features.flatten(1, 2)
+
+ for layer in self.layers:
+ latents_2d = layer(latents_2d, windowed_features, positional_encoding=None)
+
+ latents_2d = latents_2d.view(batch_size, num_windows_per_dim, num_windows_per_dim, channels).permute(
+ 0, 3, 1, 2
+ )
+
+ positional_encoding_2d = self.positional_encoding(latents_2d.shape, latents_2d.device, latents_2d.dtype).to(
+ dtype=hidden_states.dtype
+ )
+ positional_encoding_2d = positional_encoding_2d.permute(0, 2, 3, 1).flatten(1, 2)
+
+ latents_2d = latents_2d.permute(0, 2, 3, 1).flatten(1, 2)
+ latents_2d = self.layer_norm(latents_2d)
+
+ return latents_2d, positional_encoding_2d
+
+
+@auto_docstring
+class EdgeTamVideoModel(Sam2VideoModel):
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
+ # needs to be ignored, as it's a buffer and will not be correctly detected as a tied weight
+ _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
+ _keys_to_ignore_on_load_unexpected = []
+ _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)}
+
+ def __init__(self, config: EdgeTamVideoConfig):
+ super().__init__(config)
+ self.spatial_perceiver = EdgeTamVideoPerceiverResampler(config)
+
+ self.post_init()
+
+ def _build_memory_attention_inputs(
+ self,
+ temporal_positions_and_previous_outputs: list[tuple[int, dict]],
+ device: torch.device,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+ """
+ Concatenate memory features and positional embeddings from previous frames.
+
+ Returns:
+ Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate).
+ """
+ memories_to_concatenate = []
+ memory_positional_embeddings_to_concatenate = []
+
+ for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
+ if prev_output_data is None:
+ continue # Skip if no output data for this temporal position (e.g., padding frames)
+
+ # Load memory features (potentially from CPU to GPU)
+ # Perceiver-compressed features: (Batch, NumTokens, MemDim) -> (NumTokens, Batch, MemDim)
+ memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
+ memories_to_concatenate.append(memory_features.permute(1, 0, 2))
+
+ # Spatial positional encoding (potentially from CPU to GPU)
+ spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
+ spatial_memory_pos_embed = spatial_memory_pos_embed.squeeze(1).permute(1, 0, 2)
+
+ # Add temporal positional encoding
+ # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
+ combined_memory_pos_embed = (
+ spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
+ )
+ memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)
+
+ return memories_to_concatenate, memory_positional_embeddings_to_concatenate
+
+ def _prepare_memory_conditioned_features(
+ self,
+ inference_session: EdgeTamVideoInferenceSession,
+ frame_idx: int,
+ obj_idx: int,
+ is_initial_conditioning_frame: bool,
+ current_vision_features: torch.Tensor,
+ current_vision_positional_embeddings: torch.Tensor,
+ num_total_frames: int,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> torch.Tensor:
+ """
+ Fuse current frame's visual features with memory from previous frames for enhanced object tracking.
+
+ This method conditions the current frame's visual features on temporal memory from previous frames,
+ enabling consistent object tracking across video sequences. For initial conditioning frames, it uses
+ no-memory embeddings. For subsequent frames, it retrieves and integrates memory features from both
+ conditioning frames (user interactions) and non-conditioning frames (tracked results) via cross-attention.
+
+ Args:
+ inference_session (`EdgeTamVideoInferenceSession`):
+ The video inference session object.
+ frame_idx (`int`):
+ Index of the current frame being processed.
+ obj_idx (`int`):
+ Index of the object being processed.
+ is_initial_conditioning_frame (`bool`):
+ Whether this is an initial conditioning frame with user inputs (True) or a subsequent
+ tracking frame (False).
+ current_vision_features (`torch.Tensor`):
+ Highest-level vision features of shape `(seq_len, batch_size, channels)`.
+ current_vision_positional_embeddings (`torch.Tensor`):
+ Positional embedding tensors corresponding to the highest-level vision features.
+ num_total_frames (`int`):
+ Total number of frames in the video sequence.
+ track_in_reverse_time (`bool`, *optional*, defaults to `False`):
+ Whether tracking is performed in reverse temporal order.
+ streaming (`bool`, *optional*, defaults to `False`):
+ Whether this is streaming inference mode.
+
+ Returns:
+ `torch.Tensor`: Memory-conditioned feature tensor of shape `(batch_size, channels, height, width)`
+ suitable for input to the SAM decoder.
+ """
+ # Get dimensions from the highest-level (lowest-resolution) feature map
+ batch_size = current_vision_features.size(1)
+ num_channels = self.hidden_dim
+ height, width = self.backbone_feature_sizes[-1]
+ device = current_vision_features.device
+
+ # If memory is disabled (e.g., for single image SAM), return current features directly.
+ if self.num_maskmem == 0:
+ # Permute (SeqLen, Batch, Channels) -> (Batch, Channels, SeqLen) then view as (Batch, Channels, Height, Width)
+ # Assuming SeqLen = Height * Width for the last feature map
+ current_feature_map = current_vision_features.permute(1, 2, 0).view(
+ batch_size, num_channels, height, width
+ )
+ return current_feature_map
+
+ # Step 1: Handle initial conditioning frames
+ if is_initial_conditioning_frame:
+ # For initial conditioning frames, no prior memory is used directly in this block.
+ # If configured, directly add a learnable "no memory" embedding.
+ # current_vision_features has shape (SeqLen, Batch, Channels)
+ conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding
+ # Reshape to (Batch, Channels, Height, Width)
+ conditioned_feature_map = conditioned_feature_map_flat.permute(1, 2, 0).view(
+ batch_size, num_channels, height, width
+ )
+ return conditioned_feature_map
+
+ # Step 2: Get memory frames and concatenate their features
+ temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs(
+ inference_session, obj_idx, frame_idx, track_in_reverse_time
+ )
+
+ memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs(
+ temporal_positions_and_previous_outputs, device
+ )
+ num_spatial_memory_tokens = len(memories_to_concatenate)
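+ # one entry per retained memory frame; passed to memory attention so the key RoPE pattern can be
+ # repeated across frames (`rope_k_repeat`)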
+
+ # Step 3: Get and process object pointers
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers(
+ inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming
+ )
+
+ num_object_pointer_tokens = 0
+ if pointer_tokens:
+ object_pointers, object_pointers_pos_embed = self._process_object_pointers(
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device
+ )
+
+ if object_pointers is not None:
+ memories_to_concatenate.append(object_pointers)
+ memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
+ num_object_pointer_tokens = object_pointers.shape[0]
+
+ # Step 4: Concatenate all retrieved memories and their positional embeddings
+ combined_memory = torch.cat(memories_to_concatenate, dim=0)
+ combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0)
+
+ # Step 5: Forward through the memory attention mechanism
+ conditioned_feature_map_flat = self.memory_attention(
+ current_vision_features=current_vision_features,
+ current_vision_position_embeddings=current_vision_positional_embeddings,
+ memory=combined_memory,
+ memory_posision_embeddings=combined_memory_positional_embeddings, # keyword spelling matches EdgeTamVideoMemoryAttention.forward
+ num_object_pointer_tokens=num_object_pointer_tokens,
+ num_spatial_memory_tokens=num_spatial_memory_tokens,
+ )
+
+ # Reshape from (Batch, H*W, Channels) to (Batch, Channels, Height, Width)
+ conditioned_feature_map = (
+ conditioned_feature_map_flat.squeeze(1).permute(0, 2, 1).view(batch_size, num_channels, height, width)
+ )
+ return conditioned_feature_map
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats: torch.Tensor,
+ pred_masks_high_res: torch.Tensor,
+ object_score_logits: torch.Tensor,
+ is_mask_from_pts: bool,
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ """Encode the current image and its prediction into a memory feature."""
+ batch_size = current_vision_feats.size(1) # batch size on this frame
+ channels = self.hidden_dim
+ height, width = self.backbone_feature_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats.permute(1, 2, 0).view(batch_size, channels, height, width)
+ if is_mask_from_pts and not self.training:
+ # binarize the mask logits
+ mask_for_mem = (pred_masks_high_res > 0).to(pred_masks_high_res.dtype)
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ mask_for_mem = mask_for_mem * self.config.sigmoid_scale_for_mem_enc
+ mask_for_mem = mask_for_mem + self.config.sigmoid_bias_for_mem_enc
+
+ maskmem_features, maskmem_pos_enc = self.memory_encoder(
+ pix_feat,
+ mask_for_mem,
+ )
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e., no object appears in the frame)
+ if self.occlusion_spatial_embedding_parameter is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (1 - is_obj_appearing[..., None]) * self.occlusion_spatial_embedding_parameter[
+ ..., None, None
+ ].expand(*maskmem_features.shape)
+
+ maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype)
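+ # EdgeTAM-specific step: compress the dense memory map into `num_latents_1d + num_latents_2d` tokens
+ # per frame with the perceiver resampler before it is stored as per-frame memory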
+ maskmem_features, maskmem_pos_enc = self.spatial_perceiver(maskmem_features, maskmem_pos_enc)
+ maskmem_features = maskmem_features.to(pred_masks_high_res.dtype)
+ maskmem_pos_enc = maskmem_pos_enc.to(pred_masks_high_res.dtype)
+
+ return maskmem_features, maskmem_pos_enc
+
+
+__all__ = [
+ "EdgeTamVideoMaskDecoderConfig",
+ "EdgeTamVideoPromptEncoderConfig",
+ "EdgeTamVideoConfig",
+ "EdgeTamVideoModel",
+ "EdgeTamVideoInferenceSession",
+ "EdgeTamVideoPreTrainedModel",
+]
diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py
index 8e420bf27904..47b5b47d3630 100644
--- a/src/transformers/models/gemma3n/configuration_gemma3n.py
+++ b/src/transformers/models/gemma3n/configuration_gemma3n.py
@@ -551,8 +551,8 @@ def from_dict(cls, config_dict: dict[str, Any], **kwargs):
def to_dict(self) -> dict[str, Any]:
output = super().to_dict()
- output["num_classes"] = self.num_labels
- output["label_names"] = list(self.id2label.values())
+ output.setdefault("num_classes", self.num_labels)
+ output.setdefault("label_names", list(self.id2label.values()))
output.pop("id2label", None)
output.pop("label2id", None)
return output
diff --git a/src/transformers/models/sam2/configuration_sam2.py b/src/transformers/models/sam2/configuration_sam2.py
index 8a93f28d5a20..e14583181d38 100644
--- a/src/transformers/models/sam2/configuration_sam2.py
+++ b/src/transformers/models/sam2/configuration_sam2.py
@@ -379,8 +379,6 @@ class Sam2Config(PretrainedConfig):
Dictionary of configuration options used to initialize [`Sam2MaskDecoderConfig`].
initializer_range (`float`, *optional*, defaults to 0.02):
Standard deviation for parameter initialization.
- kwargs (*optional*):
- Dictionary of keyword arguments.
Example:
diff --git a/src/transformers/models/sam2_video/convert_sam2_video_to_hf.py b/src/transformers/models/sam2_video/convert_sam2_video_to_hf.py
index 322aa5507978..cc2ee0c7c612 100644
--- a/src/transformers/models/sam2_video/convert_sam2_video_to_hf.py
+++ b/src/transformers/models/sam2_video/convert_sam2_video_to_hf.py
@@ -190,7 +190,7 @@ def replace_keys(state_dict, config):
if re.match(output_vision_encoder_neck_pattern, key):
key = key.replace(".conv.", ".")
- # memory_encoder.out_proj.weight -> memory_encoder.projection.weight
+ # memory_encoder.o_proj.weight -> memory_encoder.projection.weight
if re.match(output_memory_encoder_projection_pattern, key):
key = key.replace(".o_proj.", ".projection.")
diff --git a/src/transformers/models/sam2_video/modeling_sam2_video.py b/src/transformers/models/sam2_video/modeling_sam2_video.py
index f4c1261d6779..caa07d1f63b5 100644
--- a/src/transformers/models/sam2_video/modeling_sam2_video.py
+++ b/src/transformers/models/sam2_video/modeling_sam2_video.py
@@ -134,8 +134,10 @@ def __init__(
dtype: Union[torch.dtype, str] = "float32",
max_vision_features_cache_size: int = 1,
):
- # store as a list to avoid double memory allocation with torch.cat when adding new frames
- self.processed_frames = list(video.to(video_storage_device, dtype=dtype)) if video is not None else None
+ # store as a dict keyed by frame index to avoid double memory allocation with torch.cat and to allow adding frames at arbitrary indices
+ self.processed_frames = (
+ dict(enumerate(video.to(video_storage_device, dtype=dtype))) if video is not None else None
+ )
self.video_height = video_height
self.video_width = video_width
@@ -293,18 +295,21 @@ def get_output(
return value
# Video frame management
- def add_new_frame(self, pixel_values: torch.Tensor) -> int:
+ def add_new_frame(self, pixel_values: torch.Tensor, frame_idx: Optional[int] = None) -> int:
"""Add new frame with automatic device placement."""
pixel_values = pixel_values.to(self.video_storage_device, dtype=self.dtype, non_blocking=True)
if pixel_values.dim() == 4:
pixel_values = pixel_values.squeeze(0)
+ if frame_idx is None:
+ frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0
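+ # default index assumes frames were added contiguously starting at 0; pass `frame_idx` explicitly otherwise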
+
if self.processed_frames is None:
- self.processed_frames = [pixel_values]
+ self.processed_frames = {frame_idx: pixel_values}
else:
- self.processed_frames.append(pixel_values)
+ self.processed_frames[frame_idx] = pixel_values
- return self.num_frames - 1
+ return frame_idx
def get_frame(self, frame_idx: int) -> torch.Tensor:
"""Get frame from video."""
@@ -1714,7 +1719,7 @@ def forward(
Whether to propagate in reverse.
"""
if frame is not None:
- frame_idx = inference_session.add_new_frame(frame)
+ frame_idx = inference_session.add_new_frame(frame, frame_idx)
if frame is not None and inference_session.get_obj_num() == 0:
raise ValueError("No objects are provided for tracking; please add inputs first.")
@@ -2097,6 +2102,195 @@ def _use_mask_as_output(
image_embeddings=high_res_features + [backbone_features],
)
+ def _gather_memory_frame_outputs(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ track_in_reverse_time: bool = False,
+ ) -> list[tuple[int, dict]]:
+ """
+ Get memory frames from conditioning and non-conditioning outputs.
+
+ Returns:
+ List of (relative_temporal_offset, output_data) tuples.
+ """
+ temporal_positions_and_previous_outputs = []
+
+ # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints)
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ if not conditioning_outputs:
+ raise ValueError(
+ "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame"
+ )
+
+ # Store (temporal_position, output_data) tuples
+ temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()]
+
+ # Add non-conditioning memory frames (up to self.num_maskmem - 1)
+ # These are typically frames tracked by the model without direct user input.
+ # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity.
+ for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1):
+ # relative_temporal_offset: how many frames before (or after if reversing) the current frame
+ if not track_in_reverse_time:
+ previous_frame_idx = frame_idx - relative_temporal_offset
+ else:
+ previous_frame_idx = frame_idx + relative_temporal_offset
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ previous_frame_idx, None
+ )
+
+ temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data))
+
+ return temporal_positions_and_previous_outputs
+
+ def _build_memory_attention_inputs(
+ self,
+ temporal_positions_and_previous_outputs: list[tuple[int, dict]],
+ device: torch.device,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+ """
+ Concatenate memory features and positional embeddings from previous frames.
+
+ Returns:
+ Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate).
+ """
+ memories_to_concatenate = []
+ memory_positional_embeddings_to_concatenate = []
+
+ for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
+ if prev_output_data is None:
+ continue # Skip if no output data for this temporal position (e.g., padding frames)
+
+ # Load memory features (potentially from CPU to GPU)
+ # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels)
+ memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
+ memories_to_concatenate.append(memory_features)
+
+ # Spatial positional encoding (potentially from CPU to GPU)
+ spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
+
+ # Add temporal positional encoding
+ # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
+ combined_memory_pos_embed = (
+ spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
+ )
+ memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)
+
+ return memories_to_concatenate, memory_positional_embeddings_to_concatenate
+
+ def _get_object_pointers(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ num_total_frames: int,
+ device: torch.device,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> tuple[list[int], list[torch.Tensor], int]:
+ """
+ Get object pointers and their positional embeddings from past frames.
+
+ Returns:
+ Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use).
+ """
+ temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1
+
+ # Determine max object pointers to use
+ if streaming:
+ max_object_pointers_to_use = self.config.max_object_pointers_in_encoder
+ else:
+ max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder)
+
+ temporal_offsets: list[int] = []
+ pointer_tokens: list[torch.Tensor] = []
+
+ # Add object pointers from selected conditioning frames
+ # Optionally, only include pointers from past frames during evaluation
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ eligible_conditioning_outputs = conditioning_outputs
+ if not self.training:
+ eligible_conditioning_outputs = {
+ temporal_idx: out
+ for temporal_idx, out in conditioning_outputs.items()
+ if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx)
+ }
+
+ for temporal_idx, out_data in eligible_conditioning_outputs.items():
+ temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier
+ temporal_offsets.append(temporal_difference)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1)
+ for t_diff_offset in range(1, max_object_pointers_to_use):
+ ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset
+ if ref_frame_idx < 0 or (
+ not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames
+ ):
+ break # Stop if frame index is out of bounds
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ ref_frame_idx, None
+ )
+ if out_data is not None:
+ temporal_offsets.append(t_diff_offset)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ return temporal_offsets, pointer_tokens, max_object_pointers_to_use
+
+ def _process_object_pointers(
+ self,
+ temporal_offsets: list[int],
+ pointer_tokens: list[torch.Tensor],
+ max_object_pointers_to_use: int,
+ batch_size: int,
+ num_channels: int,
+ device: torch.device,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Process object pointers and compute their positional embeddings.
+
+ Returns:
+ Tuple of (object_pointers, object_pointers_pos_embed).
+ """
+ if not pointer_tokens:
+ return None, None
+
+ # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels)
+ object_pointers = torch.stack(pointer_tokens, dim=0)
+
+ if self.config.enable_temporal_pos_encoding_for_object_pointers:
+ max_temporal_diff = float(max_object_pointers_to_use - 1)
+ # Determine dimensionality for temporal positional encoding of pointers
+ pointer_tpos_dim = num_channels
+
+ # Normalize temporal differences before sine PE calculation
+ normalized_temporal_diffs = (
+ torch.tensor(temporal_offsets, device=device, dtype=torch.float32) / max_temporal_diff
+ )
+ sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype)
+ projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe)
+ object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim)
+ else:
+ object_pointers_pos_embed = object_pointers.new_zeros(
+ len(temporal_offsets), batch_size, self.mem_dim, dtype=object_pointers.dtype
+ )
+
+ if self.mem_dim < num_channels:
+ # If memory dimension is smaller, reshape/split pointers and repeat positional encoding
+ num_splits = num_channels // self.mem_dim
+ object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim)
+ object_pointers = object_pointers.permute(0, 2, 1, 3).flatten(
+ 0, 1
+ ) # (SeqLen_ptr*num_splits, Batch, MemDim)
+ object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0)
+
+ return object_pointers, object_pointers_pos_embed
+
def _prepare_memory_conditioned_features(
self,
inference_session: Sam2VideoInferenceSession,
@@ -2157,135 +2351,9 @@ def _prepare_memory_conditioned_features(
)
return current_feature_map
- num_object_pointer_tokens = 0
- temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1
-
- # Step 1: Condition the visual features of the current frame on previous memories
- if not is_initial_conditioning_frame:
- # Retrieve memories encoded from previous frames
- memories_to_concatenate = []
- memory_positional_embeddings_to_concatenate = []
-
- # Ensure there are conditioning frame outputs to process
- conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
- if not conditioning_outputs:
- raise ValueError(
- "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame"
- )
-
- # Select a maximum number of temporally closest conditioning frames for cross-attention (no limit here, as is the case in the original checkpoints)
- # Store (temporal_position, output_data) tuples
- temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()]
-
- # Add non-conditioning memory frames (up to self.num_maskmem - 1)
- # These are typically frames tracked by the model without direct user input.
- # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity.
- for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1):
- # relative_temporal_offset: how many frames before (or after if reversing) the current frame
- if not track_in_reverse_time:
- previous_frame_idx = frame_idx - relative_temporal_offset
- else:
- previous_frame_idx = frame_idx + relative_temporal_offset
-
- # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
- output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
- previous_frame_idx, None
- )
-
- temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data))
-
- for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
- if prev_output_data is None:
- continue # Skip if no output data for this temporal position (e.g., padding frames)
-
- # Load memory features (potentially from CPU to GPU)
- # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels)
- memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
- memories_to_concatenate.append(memory_features)
-
- # Spatial positional encoding (potentially from CPU to GPU)
- spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
-
- # Add temporal positional encoding
- # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
- combined_memory_pos_embed = (
- spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
- )
- memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)
-
- # Construct the list of past object pointers to be used in attention
- if streaming:
- max_object_pointers_to_use = self.config.max_object_pointers_in_encoder
- else:
- max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder)
- temporal_diff_and_pointers = []
-
- # Add object pointers from selected conditioning frames
- # Optionally, only include pointers from past frames during evaluation
- eligible_conditioning_outputs = conditioning_outputs
- if not self.training:
- eligible_conditioning_outputs = {
- temporal_idx: out
- for temporal_idx, out in conditioning_outputs.items()
- if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx)
- }
-
- for temporal_idx, out_data in eligible_conditioning_outputs.items():
- temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier
- temporal_diff_and_pointers.append((temporal_difference, out_data["object_pointer"]))
-
- # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1)
- for t_diff_offset in range(1, max_object_pointers_to_use):
- ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset
- if ref_frame_idx < 0 or (
- not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames
- ):
- break # Stop if frame index is out of bounds
-
- # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
- out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
- ref_frame_idx, None
- )
- if out_data is not None:
- temporal_diff_and_pointers.append((t_diff_offset, out_data["object_pointer"]))
-
- if temporal_diff_and_pointers:
- temporal_differences, object_pointers_list = zip(*temporal_diff_and_pointers)
- # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels)
- object_pointers = torch.stack(object_pointers_list, dim=0)
-
- if self.config.enable_temporal_pos_encoding_for_object_pointers:
- max_temporal_diff = float(max_object_pointers_to_use - 1)
- # Determine dimensionality for temporal positional encoding of pointers
- pointer_tpos_dim = num_channels
-
- # Normalize temporal differences before sine PE calculation
- normalized_temporal_diffs = (
- torch.tensor(temporal_differences, device=device, dtype=torch.float32) / max_temporal_diff
- )
- sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype)
- projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe)
- object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim)
- else:
- object_pointers_pos_embed = object_pointers.new_zeros(
- len(temporal_differences), batch_size, self.mem_dim, dtype=object_pointers.dtype
- )
-
- if self.mem_dim < num_channels:
- # If memory dimension is smaller, reshape/split pointers and repeat positional encoding
- num_splits = num_channels // self.mem_dim
- object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim)
- object_pointers = object_pointers.permute(0, 2, 1, 3).flatten(
- 0, 1
- ) # (SeqLen_ptr*num_splits, Batch, MemDim)
- object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0)
-
- memories_to_concatenate.append(object_pointers)
- memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
- num_object_pointer_tokens = object_pointers.shape[0]
- else:
+ # Step 1: Handle initial conditioning frames
+ if is_initial_conditioning_frame:
# For initial conditioning frames, no prior memory is used directly in this block.
- # The model might handle this with a special token or mechanism.
# If configured, directly add a learnable "no memory" embedding.
# current_vision_features has shape (SeqLen, Batch, Channels)
conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding
@@ -2295,11 +2363,36 @@ def _prepare_memory_conditioned_features(
)
return conditioned_feature_map
- # Step 2: Concatenate all retrieved memories and their positional embeddings.
+ # Step 2: Get memory frames and concatenate their features
+ temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs(
+ inference_session, obj_idx, frame_idx, track_in_reverse_time
+ )
+
+ memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs(
+ temporal_positions_and_previous_outputs, device
+ )
+
+ # Step 3: Get and process object pointers
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers(
+ inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming
+ )
+
+ num_object_pointer_tokens = 0
+ if pointer_tokens:
+ object_pointers, object_pointers_pos_embed = self._process_object_pointers(
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device
+ )
+
+ if object_pointers is not None:
+ memories_to_concatenate.append(object_pointers)
+ memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
+ num_object_pointer_tokens = object_pointers.shape[0]
+
+ # Step 4: Concatenate all retrieved memories and their positional embeddings
combined_memory = torch.cat(memories_to_concatenate, dim=0)
combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0)
- # Step 3: Forward through the memory attention mechanism.
+ # Step 5: Forward through the memory attention mechanism
conditioned_feature_map_flat = self.memory_attention(
current_vision_features=current_vision_features,
current_vision_position_embeddings=current_vision_positional_embeddings,
diff --git a/src/transformers/models/sam2_video/modular_sam2_video.py b/src/transformers/models/sam2_video/modular_sam2_video.py
index 53e10998b2a7..fa0d6c21d5e6 100644
--- a/src/transformers/models/sam2_video/modular_sam2_video.py
+++ b/src/transformers/models/sam2_video/modular_sam2_video.py
@@ -403,8 +403,10 @@ def __init__(
dtype: Union[torch.dtype, str] = "float32",
max_vision_features_cache_size: int = 1,
):
- # store as a list to avoid double memory allocation with torch.cat when adding new frames
- self.processed_frames = list(video.to(video_storage_device, dtype=dtype)) if video is not None else None
+ # store as a dict keyed by frame index to avoid double memory allocation with torch.cat and to allow adding frames at arbitrary indices
+ self.processed_frames = (
+ dict(enumerate(video.to(video_storage_device, dtype=dtype))) if video is not None else None
+ )
self.video_height = video_height
self.video_width = video_width
@@ -562,18 +564,21 @@ def get_output(
return value
# Video frame management
- def add_new_frame(self, pixel_values: torch.Tensor) -> int:
+ def add_new_frame(self, pixel_values: torch.Tensor, frame_idx: Optional[int] = None) -> int:
"""Add new frame with automatic device placement."""
pixel_values = pixel_values.to(self.video_storage_device, dtype=self.dtype, non_blocking=True)
if pixel_values.dim() == 4:
pixel_values = pixel_values.squeeze(0)
+ if frame_idx is None:
+ frame_idx = len(self.processed_frames) if self.processed_frames is not None else 0
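+ # default index assumes frames were added contiguously starting at 0; pass `frame_idx` explicitly otherwise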
+
if self.processed_frames is None:
- self.processed_frames = [pixel_values]
+ self.processed_frames = {frame_idx: pixel_values}
else:
- self.processed_frames.append(pixel_values)
+ self.processed_frames[frame_idx] = pixel_values
- return self.num_frames - 1
+ return frame_idx
def get_frame(self, frame_idx: int) -> torch.Tensor:
"""Get frame from video."""
@@ -1799,6 +1804,195 @@ def _use_mask_as_output(
image_embeddings=high_res_features + [backbone_features],
)
+ def _gather_memory_frame_outputs(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ track_in_reverse_time: bool = False,
+ ) -> list[tuple[int, dict]]:
+ """
+ Get memory frames from conditioning and non-conditioning outputs.
+
+ Returns:
+ List of (relative_temporal_offset, output_data) tuples.
+ """
+ temporal_positions_and_previous_outputs = []
+
+ # Add conditioning frame outputs (no limit here, as is the case in the original checkpoints)
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ if not conditioning_outputs:
+ raise ValueError(
+ "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame"
+ )
+
+ # Store (temporal_position, output_data) tuples
+ temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()]
+
+ # Add non-conditioning memory frames (up to self.num_maskmem - 1)
+ # These are typically frames tracked by the model without direct user input.
+ # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity.
+ for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1):
+ # relative_temporal_offset: how many frames before (or after if reversing) the current frame
+ if not track_in_reverse_time:
+ previous_frame_idx = frame_idx - relative_temporal_offset
+ else:
+ previous_frame_idx = frame_idx + relative_temporal_offset
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ previous_frame_idx, None
+ )
+
+ temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data))
+
+ return temporal_positions_and_previous_outputs
+
+ def _build_memory_attention_inputs(
+ self,
+ temporal_positions_and_previous_outputs: list[tuple[int, dict]],
+ device: torch.device,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+ """
+ Concatenate memory features and positional embeddings from previous frames.
+
+ Returns:
+ Tuple of (memories_to_concatenate, memory_positional_embeddings_to_concatenate).
+ """
+ memories_to_concatenate = []
+ memory_positional_embeddings_to_concatenate = []
+
+ for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
+ if prev_output_data is None:
+ continue # Skip if no output data for this temporal position (e.g., padding frames)
+
+ # Load memory features (potentially from CPU to GPU)
+ # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels)
+ memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
+ memories_to_concatenate.append(memory_features)
+
+ # Spatial positional encoding (potentially from CPU to GPU)
+ spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
+
+ # Add temporal positional encoding
+ # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
+ combined_memory_pos_embed = (
+ spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
+ )
+ memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)
+
+ return memories_to_concatenate, memory_positional_embeddings_to_concatenate
+
+ def _get_object_pointers(
+ self,
+ inference_session: Sam2VideoInferenceSession,
+ obj_idx: int,
+ frame_idx: int,
+ num_total_frames: int,
+ device: torch.device,
+ track_in_reverse_time: bool = False,
+ streaming: bool = False,
+ ) -> tuple[list[int], list[torch.Tensor], int]:
+ """
+ Get object pointers and their positional embeddings from past frames.
+
+ Returns:
+ Tuple of (temporal_offsets, pointer_tokens, max_object_pointers_to_use).
+ """
+ temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1
+
+ # Determine max object pointers to use
+ if streaming:
+ max_object_pointers_to_use = self.config.max_object_pointers_in_encoder
+ else:
+ max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder)
+
+ temporal_offsets: list[int] = []
+ pointer_tokens: list[torch.Tensor] = []
+
+ # Add object pointers from selected conditioning frames
+ # Optionally, only include pointers from past frames during evaluation
+ conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
+ eligible_conditioning_outputs = conditioning_outputs
+ if not self.training:
+ eligible_conditioning_outputs = {
+ temporal_idx: out
+ for temporal_idx, out in conditioning_outputs.items()
+ if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx)
+ }
+
+ for temporal_idx, out_data in eligible_conditioning_outputs.items():
+ temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier
+ temporal_offsets.append(temporal_difference)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1)
+ for t_diff_offset in range(1, max_object_pointers_to_use):
+ ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset
+ if ref_frame_idx < 0 or (
+ not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames
+ ):
+ break # Stop if frame index is out of bounds
+
+ # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
+ out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
+ ref_frame_idx, None
+ )
+ if out_data is not None:
+ temporal_offsets.append(t_diff_offset)
+ pointer_tokens.append(out_data["object_pointer"].to(device))
+
+ return temporal_offsets, pointer_tokens, max_object_pointers_to_use
+
+ def _process_object_pointers(
+ self,
+ temporal_offsets: list[int],
+ pointer_tokens: list[torch.Tensor],
+ max_object_pointers_to_use: int,
+ batch_size: int,
+ num_channels: int,
+ device: torch.device,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Process object pointers and compute their positional embeddings.
+
+ Returns:
+ Tuple of (object_pointers, object_pointers_pos_embed).
+ """
+ if not pointer_tokens:
+ return None, None
+
+ # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels)
+ object_pointers = torch.stack(pointer_tokens, dim=0)
+
+ if self.config.enable_temporal_pos_encoding_for_object_pointers:
+ max_temporal_diff = float(max_object_pointers_to_use - 1)
+ # Determine dimensionality for temporal positional encoding of pointers
+ pointer_tpos_dim = num_channels
+
+ # Normalize temporal differences before sine PE calculation
+ normalized_temporal_diffs = (
+ torch.tensor(temporal_offsets, device=device, dtype=torch.float32) / max_temporal_diff
+ )
+ sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype)
+ projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe)
+ object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim)
+ else:
+ object_pointers_pos_embed = object_pointers.new_zeros(
+ len(temporal_offsets), batch_size, self.mem_dim, dtype=object_pointers.dtype
+ )
+
+ if self.mem_dim < num_channels:
+ # If memory dimension is smaller, reshape/split pointers and repeat positional encoding
+ num_splits = num_channels // self.mem_dim
+ object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim)
+ object_pointers = object_pointers.permute(0, 2, 1, 3).flatten(
+ 0, 1
+ ) # (SeqLen_ptr*num_splits, Batch, MemDim)
+ object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0)
+
+ return object_pointers, object_pointers_pos_embed
+
def _prepare_memory_conditioned_features(
self,
inference_session: Sam2VideoInferenceSession,
@@ -1859,135 +2053,9 @@ def _prepare_memory_conditioned_features(
)
return current_feature_map
- num_object_pointer_tokens = 0
- temporal_position_sign_multiplier = -1 if track_in_reverse_time else 1
-
- # Step 1: Condition the visual features of the current frame on previous memories
- if not is_initial_conditioning_frame:
- # Retrieve memories encoded from previous frames
- memories_to_concatenate = []
- memory_positional_embeddings_to_concatenate = []
-
- # Ensure there are conditioning frame outputs to process
- conditioning_outputs = inference_session.output_dict_per_obj[obj_idx]["cond_frame_outputs"]
- if not conditioning_outputs:
- raise ValueError(
- "maskmem_features in conditioning outputs cannot be empty when not is_initial_conditioning_frame"
- )
-
- # Select a maximum number of temporally closest conditioning frames for cross-attention (no limit here, as is the case in the original checkpoints)
- # Store (temporal_position, output_data) tuples
- temporal_positions_and_previous_outputs = [(0, out) for out in conditioning_outputs.values()]
-
- # Add non-conditioning memory frames (up to self.num_maskmem - 1)
- # These are typically frames tracked by the model without direct user input.
- # Frames are selected with a stride, prioritizing the most recent ones. Here we only support stride = 1 for simplicity.
- for relative_temporal_offset in range(self.num_maskmem - 1, 0, -1):
- # relative_temporal_offset: how many frames before (or after if reversing) the current frame
- if not track_in_reverse_time:
- previous_frame_idx = frame_idx - relative_temporal_offset
- else:
- previous_frame_idx = frame_idx + relative_temporal_offset
-
- # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
- output_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
- previous_frame_idx, None
- )
-
- temporal_positions_and_previous_outputs.append((relative_temporal_offset, output_data))
-
- for relative_temporal_offset, prev_output_data in temporal_positions_and_previous_outputs:
- if prev_output_data is None:
- continue # Skip if no output data for this temporal position (e.g., padding frames)
-
- # Load memory features (potentially from CPU to GPU)
- # Features are flattened: (Batch, Channels, H, W) -> (H*W, Batch, Channels)
- memory_features = prev_output_data["maskmem_features"].to(device, non_blocking=True)
- memories_to_concatenate.append(memory_features)
-
- # Spatial positional encoding (potentially from CPU to GPU)
- spatial_memory_pos_embed = prev_output_data["maskmem_pos_enc"].to(device, non_blocking=True)
-
- # Add temporal positional encoding
- # self.memory_temporal_positional_encoding shape: (NumMaskMem, 1, 1, MemDim)
- combined_memory_pos_embed = (
- spatial_memory_pos_embed + self.memory_temporal_positional_encoding[relative_temporal_offset - 1]
- )
- memory_positional_embeddings_to_concatenate.append(combined_memory_pos_embed)
-
- # Construct the list of past object pointers to be used in attention
- if streaming:
- max_object_pointers_to_use = self.config.max_object_pointers_in_encoder
- else:
- max_object_pointers_to_use = min(num_total_frames, self.config.max_object_pointers_in_encoder)
- temporal_diff_and_pointers = []
-
- # Add object pointers from selected conditioning frames
- # Optionally, only include pointers from past frames during evaluation
- eligible_conditioning_outputs = conditioning_outputs
- if not self.training:
- eligible_conditioning_outputs = {
- temporal_idx: out
- for temporal_idx, out in conditioning_outputs.items()
- if (temporal_idx >= frame_idx if track_in_reverse_time else temporal_idx <= frame_idx)
- }
-
- for temporal_idx, out_data in eligible_conditioning_outputs.items():
- temporal_difference = (frame_idx - temporal_idx) * temporal_position_sign_multiplier
- temporal_diff_and_pointers.append((temporal_difference, out_data["object_pointer"]))
-
- # Add object pointers from non-conditioning frames (up to max_object_pointers_to_use - 1)
- for t_diff_offset in range(1, max_object_pointers_to_use):
- ref_frame_idx = frame_idx + t_diff_offset if track_in_reverse_time else frame_idx - t_diff_offset
- if ref_frame_idx < 0 or (
- not streaming and num_total_frames is not None and ref_frame_idx >= num_total_frames
- ):
- break # Stop if frame index is out of bounds
-
- # check if the output is already stored without using get_output to avoid unnecessary memory transfers between CPU and GPU
- out_data = inference_session.output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].get(
- ref_frame_idx, None
- )
- if out_data is not None:
- temporal_diff_and_pointers.append((t_diff_offset, out_data["object_pointer"]))
-
- if temporal_diff_and_pointers:
- temporal_differences, object_pointers_list = zip(*temporal_diff_and_pointers)
- # Stack object pointers: List of (Batch, Channels) -> (SeqLen_ptr, Batch, Channels)
- object_pointers = torch.stack(object_pointers_list, dim=0)
-
- if self.config.enable_temporal_pos_encoding_for_object_pointers:
- max_temporal_diff = float(max_object_pointers_to_use - 1)
- # Determine dimensionality for temporal positional encoding of pointers
- pointer_tpos_dim = num_channels
-
- # Normalize temporal differences before sine PE calculation
- normalized_temporal_diffs = (
- torch.tensor(temporal_differences, device=device, dtype=torch.float32) / max_temporal_diff
- )
- sine_pe = get_1d_sine_pe(normalized_temporal_diffs, dim=pointer_tpos_dim).to(object_pointers.dtype)
- projected_sine_pe = self.temporal_positional_encoding_projection_layer(sine_pe)
- object_pointers_pos_embed = projected_sine_pe.unsqueeze(1).expand(-1, batch_size, self.mem_dim)
- else:
- object_pointers_pos_embed = object_pointers.new_zeros(
- len(temporal_differences), batch_size, self.mem_dim, dtype=object_pointers.dtype
- )
-
- if self.mem_dim < num_channels:
- # If memory dimension is smaller, reshape/split pointers and repeat positional encoding
- num_splits = num_channels // self.mem_dim
- object_pointers = object_pointers.reshape(-1, batch_size, num_splits, self.mem_dim)
- object_pointers = object_pointers.permute(0, 2, 1, 3).flatten(
- 0, 1
- ) # (SeqLen_ptr*num_splits, Batch, MemDim)
- object_pointers_pos_embed = object_pointers_pos_embed.repeat_interleave(num_splits, dim=0)
-
- memories_to_concatenate.append(object_pointers)
- memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
- num_object_pointer_tokens = object_pointers.shape[0]
- else:
+ # Step 1: Handle initial conditioning frames
+ if is_initial_conditioning_frame:
# For initial conditioning frames, no prior memory is used directly in this block.
- # The model might handle this with a special token or mechanism.
# If configured, directly add a learnable "no memory" embedding.
# current_vision_features has shape (SeqLen, Batch, Channels)
conditioned_feature_map_flat = current_vision_features + self.no_memory_embedding
@@ -1997,11 +2065,36 @@ def _prepare_memory_conditioned_features(
)
return conditioned_feature_map
- # Step 2: Concatenate all retrieved memories and their positional embeddings.
+ # Step 2: Get memory frames and concatenate their features
+ temporal_positions_and_previous_outputs = self._gather_memory_frame_outputs(
+ inference_session, obj_idx, frame_idx, track_in_reverse_time
+ )
+
+ memories_to_concatenate, memory_positional_embeddings_to_concatenate = self._build_memory_attention_inputs(
+ temporal_positions_and_previous_outputs, device
+ )
+
+ # Step 3: Get and process object pointers
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use = self._get_object_pointers(
+ inference_session, obj_idx, frame_idx, num_total_frames, device, track_in_reverse_time, streaming
+ )
+
+ num_object_pointer_tokens = 0
+ if pointer_tokens:
+ object_pointers, object_pointers_pos_embed = self._process_object_pointers(
+ temporal_offsets, pointer_tokens, max_object_pointers_to_use, batch_size, num_channels, device
+ )
+
+ if object_pointers is not None:
+ memories_to_concatenate.append(object_pointers)
+ memory_positional_embeddings_to_concatenate.append(object_pointers_pos_embed)
+ num_object_pointer_tokens = object_pointers.shape[0]
+
+ # Step 4: Concatenate all retrieved memories and their positional embeddings
combined_memory = torch.cat(memories_to_concatenate, dim=0)
combined_memory_positional_embeddings = torch.cat(memory_positional_embeddings_to_concatenate, dim=0)
- # Step 3: Forward through the memory attention mechanism.
+ # Step 5: Forward through the memory attention mechanism
conditioned_feature_map_flat = self.memory_attention(
current_vision_features=current_vision_features,
current_vision_position_embeddings=current_vision_positional_embeddings,
@@ -2211,7 +2304,7 @@ def forward(
Whether to propagate in reverse.
"""
if frame is not None:
- frame_idx = inference_session.add_new_frame(frame)
+ frame_idx = inference_session.add_new_frame(frame, frame_idx)
if frame is not None and inference_session.get_obj_num() == 0:
raise ValueError("No objects are provided for tracking; please add inputs first.")
diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py
index 24142232241f..34e640ade8bf 100644
--- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py
@@ -121,8 +121,8 @@ def from_dict(cls, config_dict: dict[str, Any], **kwargs):
def to_dict(self) -> dict[str, Any]:
output = super().to_dict()
- output["num_classes"] = self.num_labels
- output["label_names"] = list(self.id2label.values())
+ output.setdefault("num_classes", self.num_labels)
+ output.setdefault("label_names", list(self.id2label.values()))
output.pop("id2label", None)
output.pop("label2id", None)
return output
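On the `TimmWrapperConfig.to_dict` change: switching to `setdefault` means values that are already present in the serialized output (for example when they were supplied through `model_args` for a features-only backbone) are no longer overwritten. A small, self-contained illustration of the semantics, using a plain dict rather than the real config class:

```python
# Plain-dict illustration of the overwrite -> setdefault change; values are hypothetical.
output = {"num_classes": 0}

# Old behavior: always overwritten.
output["num_classes"] = 1000
assert output["num_classes"] == 1000

# New behavior: kept if already present, only filled in when missing.
output = {"num_classes": 0}
output.setdefault("num_classes", 1000)
output.setdefault("label_names", ["cat", "dog"])
assert output == {"num_classes": 0, "label_names": ["cat", "dog"]}
```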
diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index cfc3c1c104d3..d388ff05297f 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -160,6 +160,7 @@ def __init__(self, config: TimmWrapperConfig):
super().__init__(config)
# using num_classes=0 to avoid creating classification head
extra_init_kwargs = config.model_args or {}
+ self.features_only = extra_init_kwargs.get("features_only", False)
self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
self.post_init()
@@ -233,20 +234,25 @@ def forward(
pixel_values = pixel_values.to(self.device, self.dtype)
- if output_hidden_states:
- # to enable hidden states selection
- if isinstance(output_hidden_states, (list, tuple)):
- kwargs["indices"] = output_hidden_states
- last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
- else:
- last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs)
- hidden_states = None
-
- if do_pooling:
- # classification head is not created, applying pooling only
- pooler_output = self.timm_model.forward_head(last_hidden_state)
- else:
+ if self.features_only:
+ last_hidden_state = self.timm_model.forward(pixel_values, **kwargs)
+ hidden_states = last_hidden_state if output_hidden_states else None
pooler_output = None
+ else:
+ if output_hidden_states:
+ # to enable hidden states selection
+ if isinstance(output_hidden_states, (list, tuple)):
+ kwargs["indices"] = output_hidden_states
+ last_hidden_state, hidden_states = self.timm_model.forward_intermediates(pixel_values, **kwargs)
+ else:
+ last_hidden_state = self.timm_model.forward_features(pixel_values, **kwargs)
+ hidden_states = None
+
+ if do_pooling:
+ # classification head is not created, applying pooling only
+ pooler_output = self.timm_model.forward_head(last_hidden_state)
+ else:
+ pooler_output = None
if not return_dict:
outputs = (last_hidden_state, pooler_output, hidden_states)
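For context on the new `features_only` branch in `TimmWrapperModel.forward`: a timm model created with `features_only=True` returns a list of intermediate feature maps when called, so there is no single pooled output and `forward_head`/`forward_intermediates` do not apply. A hedged sketch of that timm behavior (using `resnet18` only for brevity; the EdgeTAM backbone config points at `timm/repvit_m1.dist_in1k`):

```python
# Sketch of timm's features_only behavior that the new branch relies on.
# Assumes `timm` is installed; resnet18 is used here only as a small example.
import timm
import torch

backbone = timm.create_model(
    "resnet18",
    pretrained=False,
    features_only=True,
    out_indices=(0, 1, 2, 3),
)
pixel_values = torch.randn(1, 3, 224, 224)
feature_maps = backbone(pixel_values)  # list of 4 feature maps, one per selected stage
for idx, fmap in enumerate(feature_maps):
    print(idx, tuple(fmap.shape))
```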
diff --git a/tests/models/edgetam/__init__.py b/tests/models/edgetam/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/edgetam/test_modeling_edgetam.py b/tests/models/edgetam/test_modeling_edgetam.py
new file mode 100644
index 000000000000..701642a43d41
--- /dev/null
+++ b/tests/models/edgetam/test_modeling_edgetam.py
@@ -0,0 +1,734 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch EDGETAM model."""
+
+import gc
+import tempfile
+import unittest
+
+import requests
+
+from transformers import (
+ EdgeTamConfig,
+ EdgeTamMaskDecoderConfig,
+ EdgeTamPromptEncoderConfig,
+ EdgeTamVisionConfig,
+ Sam2Processor,
+ pipeline,
+)
+from transformers.testing_utils import (
+ backend_empty_cache,
+ require_torch,
+ slow,
+ torch_device,
+)
+from transformers.utils import is_torch_available, is_vision_available
+from transformers.video_utils import load_video
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import AutoConfig, EdgeTamModel, Sam2Processor
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+class EdgeTamPromptEncoderTester:
+ def __init__(
+ self,
+ hidden_size=32,
+ input_image_size=128,
+ patch_size=16,
+ mask_input_channels=8,
+ num_point_embeddings=4,
+ hidden_act="gelu",
+ ):
+ self.hidden_size = hidden_size
+ self.input_image_size = input_image_size
+ self.patch_size = patch_size
+ self.mask_input_channels = mask_input_channels
+ self.num_point_embeddings = num_point_embeddings
+ self.hidden_act = hidden_act
+
+ def get_config(self):
+ return EdgeTamPromptEncoderConfig(
+ image_size=self.input_image_size,
+ patch_size=self.patch_size,
+ mask_input_channels=self.mask_input_channels,
+ hidden_size=self.hidden_size,
+ num_point_embeddings=self.num_point_embeddings,
+ hidden_act=self.hidden_act,
+ )
+
+ def prepare_config_and_inputs(self):
+ dummy_points = floats_tensor([self.batch_size, 3, 2])
+ config = self.get_config()
+
+ return config, dummy_points
+
+
+class EdgeTamMaskDecoderTester:
+ def __init__(
+ self,
+ hidden_size=32,
+ hidden_act="relu",
+ mlp_dim=64,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ attention_downsample_rate=2,
+ num_multimask_outputs=3,
+ iou_head_depth=3,
+ iou_head_hidden_dim=32,
+ ):
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.mlp_dim = mlp_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.attention_downsample_rate = attention_downsample_rate
+ self.num_multimask_outputs = num_multimask_outputs
+ self.iou_head_depth = iou_head_depth
+ self.iou_head_hidden_dim = iou_head_hidden_dim
+
+ def get_config(self):
+ return EdgeTamMaskDecoderConfig(
+ hidden_size=self.hidden_size,
+ hidden_act=self.hidden_act,
+ mlp_dim=self.mlp_dim,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ attention_downsample_rate=self.attention_downsample_rate,
+ num_multimask_outputs=self.num_multimask_outputs,
+ iou_head_depth=self.iou_head_depth,
+ iou_head_hidden_dim=self.iou_head_hidden_dim,
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+
+ dummy_inputs = {
+ "image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
+ }
+
+ return config, dummy_inputs
+
+
+class EdgeTamModelTester:
+ def __init__(
+ self,
+ parent,
+ num_channels=3,
+ image_size=128,
+ hidden_size=12,
+ patch_kernel_size=7,
+ patch_stride=4,
+ patch_padding=3,
+ dim_mul=2.0,
+ backbone_channel_list=[96, 48, 24, 12],
+ backbone_feature_sizes=[[32, 32], [16, 16], [8, 8]],
+ fpn_hidden_size=32,
+ memory_encoder_hidden_size=32,
+ batch_size=2,
+ is_training=False,
+ ):
+ self.parent = parent
+ self.image_size = image_size
+ self.hidden_size = hidden_size
+ self.patch_kernel_size = patch_kernel_size
+ self.patch_stride = patch_stride
+ self.patch_padding = patch_padding
+ self.dim_mul = dim_mul
+ self.backbone_channel_list = backbone_channel_list
+ self.backbone_feature_sizes = backbone_feature_sizes
+ self.fpn_hidden_size = fpn_hidden_size
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.memory_encoder_hidden_size = memory_encoder_hidden_size
+
+ self.prompt_encoder_tester = EdgeTamPromptEncoderTester()
+ self.mask_decoder_tester = EdgeTamMaskDecoderTester()
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ config = self.get_config()
+
+ return config, pixel_values
+
+ def get_config(self):
+ vision_config = EdgeTamVisionConfig(
+ backbone_config=AutoConfig.from_pretrained(
+ "timm/repvit_m1.dist_in1k",
+ model_args={
+ "in_chans": 3,
+ "features_only": True,
+ "out_indices": (0, 1, 2, 3),
+ "embed_dim": self.backbone_channel_list[::-1],
+ },
+ ),
+ backbone_channel_list=self.backbone_channel_list,
+ backbone_feature_sizes=self.backbone_feature_sizes,
+ fpn_hidden_size=self.fpn_hidden_size,
+ )
+
+ prompt_encoder_config = self.prompt_encoder_tester.get_config()
+
+ mask_decoder_config = self.mask_decoder_tester.get_config()
+
+ return EdgeTamConfig(
+ vision_config=vision_config,
+ prompt_encoder_config=prompt_encoder_config,
+ mask_decoder_config=mask_decoder_config,
+ memory_attention_hidden_size=self.hidden_size,
+ memory_encoder_hidden_size=self.memory_encoder_hidden_size,
+ image_size=self.image_size,
+ mask_downsampler_embed_dim=32,
+ memory_fuser_embed_dim=32,
+ memory_attention_num_layers=1,
+ memory_attention_feed_forward_hidden_size=32,
+ )
+
+ def create_and_check_model(self, config, pixel_values):
+ model = EdgeTamModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(pixel_values)
+ self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
+ self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class EdgeTamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as EdgeTAM's vision encoder does not use
+ input_ids, inputs_embeds, attention_mask and seq_length.
+ """
+
+ all_model_classes = (EdgeTamModel,) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {"feature-extraction": EdgeTamModel, "mask-generation": EdgeTamModel} if is_torch_available() else {}
+ )
+ fx_compatible = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+ test_torchscript = False
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = EdgeTamModelTester(self)
+ common_properties = ["initializer_range"]
+ self.config_tester = ConfigTester(
+ self, config_class=EdgeTamConfig, has_text_modality=False, common_properties=common_properties
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="Timm model does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="Can't get or set embeddings for Timm model")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ # Override as EdgeTamModel doesn't have hidden states
+ def flash_attn_inference_equivalence(self, attn_implementation: str, padding_side: str):
+ r"""
+ Tests the equivalence between the eager and flash attention implementations.
+ This test is only for inference and runs with `torch_dtype=torch.bfloat16`.
+ """
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_model_classes:
+ if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
+ attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
+ ):
+ self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ dummy_attention_mask = dummy_attention_mask[:1]
+ if padding_side == "left":
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+ else:
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = outputs.vision_hidden_states[-1]
+ logits_fa = outputs_fa.vision_hidden_states[-1]
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ if model.config.is_encoder_decoder:
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+ else:
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = outputs.vision_hidden_states[-1]
+ logits_fa = outputs_fa.vision_hidden_states[-1]
+
+ if padding_side == "left":
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
+ else:
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
+ # Override as the difference is slightly higher than the threshold
+ # def test_batching_equivalence(self, atol=5e-4, rtol=5e-4):
+ # super().test_batching_equivalence(atol=atol, rtol=rtol)
+
+ @unittest.skip(reason="TimmWrapperModel does not support an attention implementation")
+ def test_can_set_attention_dynamically_composite_model(self):
+ pass
+
+ @unittest.skip(reason="vision_hidden_states from TimmWrapperModel")
+ def test_hidden_states_output(self):
+ pass
+
+ @unittest.skip(reason="Timm weights cannot be fully constructed in _init_weights")
+ def test_can_init_all_missing_weights(self):
+ pass
+
+ @unittest.skip(reason="Timm weights cannot be fully constructed in _init_weights")
+ def test_initialization(self):
+ pass
+
+ @unittest.skip(
+ reason="TIMM's attention implementation is self configured and won't raise ValueError on global attention implementation."
+ )
+ def test_flash_attn_2_can_dispatch_composite_models(self):
+ pass
+
+ @unittest.skip("TimmWrapperModel cannot be tested with meta device")
+ def test_can_be_initialized_on_meta(self):
+ pass
+
+ @unittest.skip("TimmWrapperModel cannot be tested with meta device")
+ def test_can_load_with_meta_device_context_manager(self):
+ pass
+
+ ## Skip flash attention related tests below
+ ## The correct configuration would be:
+ ## from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2", "vision_config": "eager"})
+ @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
+ def test_eager_matches_fa2_generate(self):
+ pass
+
+ @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
+ def test_flash_attn_2_fp32_ln(self):
+ pass
+
+ @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
+ def test_flash_attn_2_from_config(self):
+ pass
+
+ @unittest.skip("SDPA test is not configured correctly as we need to configure vision/timm model to 'eager'.")
+ def test_eager_matches_sdpa_generate_with_dynamic_cache(self):
+ pass
+
+ @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
+ def test_flash_attn_2_inference_equivalence_right_padding(self):
+ pass
+
+ @unittest.skip("SDPA test is not configured correctly as we need to configure vision/timm model to 'eager'.")
+ def test_eager_matches_sdpa_generate(self):
+ pass
+
+ @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
+ def test_flash_attn_2_inference_equivalence(self):
+ pass
+
+ @unittest.skip("EdgeTAM does not have language_model, vision_tower, multi_modal_projector.")
+ def test_sdpa_can_dispatch_composite_models(self):
+ pass
+
+ @unittest.skip("Cannot set `output_attentions` for timm models.")
+ def test_attention_outputs(self):
+ pass
+
+ @unittest.skip("Cannot set `output_attentions` for timm models.")
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip("Cannot set `output_attentions` for timm models.")
+ def test_generate_compilation_all_outputs(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ model_name = "yonigozlan/EdgeTAM-hf"
+ model = EdgeTamModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ def test_sdpa_can_compile_dynamic(self):
+ self.skipTest(reason="EDGETAM model can't be compiled dynamic yet")
+
+
+def prepare_image():
+ img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ return raw_image
+
+
+def prepare_groceries_image():
+ img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg"
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ return raw_image
+
+
+def prepare_dog_img():
+ img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ return raw_image
+
+
+def prepare_video():
+ video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
+ raw_video, _ = load_video(video_url)
+ return raw_video
+
+
+@slow
+class EdgeTamModelIntegrationTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.model = EdgeTamModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(torch.float32)
+ self.processor = Sam2Processor.from_pretrained("yonigozlan/EdgeTAM-hf")
+ self.model.to(torch_device)
+ self.model.eval()
+
+ def tearDown(self):
+ super().tearDown()
+ # clean up as much GPU memory occupied by PyTorch as possible
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_inference_mask_generation_one_point_multimask(self):
+ raw_image = prepare_image()
+ input_points = [[[[500, 375]]]]
+ input_labels = [[[1]]]
+
+ inputs = self.processor(
+ images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
+ ).to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+ self.assertEqual(outputs.iou_scores.shape, (1, 1, 3))
+ self.assertEqual(outputs.pred_masks.shape, (1, 1, 3, 256, 256))
+ sorted_indices = torch.argsort(outputs.iou_scores.squeeze(), descending=True)
+ scores = outputs.iou_scores.squeeze()[sorted_indices]
+ masks_logits = outputs.pred_masks.squeeze()[sorted_indices][0, :3, :3]
+ torch.testing.assert_close(
+ scores, torch.tensor([0.7621, 0.4859, 0.0461]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ masks_logits,
+ torch.tensor(
+ [[-19.5483, -22.3549, -26.0962], [-18.1821, -23.4761, -24.2262], [-20.3549, -24.5518, -22.7232]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_mask_generation_one_point_no_multimask(self):
+ raw_image = prepare_image()
+ input_points = [[[[500, 375]]]]
+ input_labels = [[[1]]]
+
+ inputs = self.processor(
+ images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
+ ).to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs, multimask_output=False)
+ self.assertEqual(outputs.iou_scores.shape, (1, 1, 1))
+ self.assertEqual(outputs.pred_masks.shape, (1, 1, 1, 256, 256))
+ scores = outputs.iou_scores.squeeze((0, 1))
+ masks_logits = outputs.pred_masks.squeeze((0, 1))[0, :3, :3]
+ torch.testing.assert_close(scores, torch.tensor([0.7621]).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ masks_logits,
+ torch.tensor(
+ [[-19.5483, -22.3549, -26.0962], [-18.1821, -23.4761, -24.2262], [-20.3549, -24.5518, -22.7232]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_mask_generation_batched_images_multi_points(self):
+ raw_image1 = prepare_image()
+ raw_image2 = prepare_dog_img()
+ input_points = [[[[500, 375]]], [[[770, 200], [730, 120]]]]
+ input_labels = [[[1]], [[1, 0]]]
+
+ inputs = self.processor(
+ images=[raw_image1, raw_image2], input_points=input_points, input_labels=input_labels, return_tensors="pt"
+ ).to(torch_device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+ self.assertEqual(outputs.iou_scores.shape, (2, 1, 3))
+ self.assertEqual(outputs.pred_masks.shape, (2, 1, 3, 256, 256))
+
+ sorted_indices = torch.argsort(outputs.iou_scores[0].squeeze(), descending=True)
+ scores1 = outputs.iou_scores[0].squeeze()[sorted_indices]
+ masks_logits1 = outputs.pred_masks[0].squeeze()[sorted_indices][0, :3, :3]
+ sorted_indices = torch.argsort(outputs.iou_scores[1].squeeze(), descending=True)
+ scores2 = outputs.iou_scores[1].squeeze()[sorted_indices]
+ masks_logits2 = outputs.pred_masks[1].squeeze()[sorted_indices][0, :3, :3]
+ torch.testing.assert_close(
+ scores1, torch.tensor([0.7490, 0.4685, 0.0463]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ masks_logits1,
+ torch.tensor(
+ [[-19.1423, -21.6488, -25.6816], [-17.8018, -22.6512, -23.5699], [-19.9140, -23.6919, -22.3147]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ torch.testing.assert_close(
+ scores2, torch.tensor([0.7225, 0.6515, 0.6350]).to(torch_device), atol=1e-4, rtol=1e-4
+ )
+ torch.testing.assert_close(
+ masks_logits2,
+ torch.tensor([[-8.8259, -7.7961, -9.3665], [-8.2648, -8.7771, -9.1390], [-9.5951, -8.3995, -9.0599]]).to(
+ torch_device
+ ),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_mask_generation_batched_images_batched_points_multi_points(self):
+ raw_image1 = prepare_image()
+ raw_image2 = prepare_groceries_image()
+ input_points = [[[[500, 375]], [[650, 750]]], [[[400, 300]], [[630, 300], [550, 300]]]]
+ input_labels = [[[1], [1]], [[1], [1, 1]]]
+ inputs = self.processor(
+ images=[raw_image1, raw_image2], input_points=input_points, input_labels=input_labels, return_tensors="pt"
+ ).to(torch_device)
+ with torch.no_grad():
+ outputs = self.model(**inputs, multimask_output=False)
+ self.assertEqual(outputs.iou_scores.shape, (2, 2, 1))
+ self.assertEqual(outputs.pred_masks.shape, (2, 2, 1, 256, 256))
+ torch.testing.assert_close(
+ outputs.iou_scores,
+ torch.tensor([[[0.7490], [0.9397]], [[0.7952], [0.8723]]]).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[:, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-19.1423, -21.6488], [-17.8018, -22.6512]]], [[[-7.1591, -9.8201], [-7.4133, -9.2781]]]],
+ [[[[-16.7645, -15.2790], [-16.1805, -16.2937]]], [[[-8.5934, -8.4215], [-8.1873, -8.3722]]]],
+ ]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_batched_images_batched_boxes(self):
+ raw_image1 = prepare_image()
+ raw_image2 = prepare_groceries_image()
+ input_boxes = [
+ [[75, 275, 1725, 850], [425, 600, 700, 875], [1375, 550, 1650, 800], [1240, 675, 1400, 750]],
+ [[450, 170, 520, 350], [350, 190, 450, 350], [500, 170, 580, 350], [580, 170, 640, 350]],
+ ]
+ inputs = self.processor(images=[raw_image1, raw_image2], input_boxes=input_boxes, return_tensors="pt").to(
+ torch_device
+ )
+ with torch.no_grad():
+ outputs = self.model(**inputs, multimask_output=False)
+ self.assertEqual(outputs.iou_scores.shape, (2, 4, 1))
+ self.assertEqual(outputs.pred_masks.shape, (2, 4, 1, 256, 256))
+ torch.testing.assert_close(
+ outputs.iou_scores,
+ torch.tensor([[[0.9773], [0.9415], [0.9683], [0.8792]], [[0.9721], [0.9852], [0.9812], [0.9760]]]).to(
+ torch_device
+ ),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+ torch.testing.assert_close(
+ outputs.pred_masks[:, :, :, :2, :2],
+ torch.tensor(
+ [
+ [
+ [[[-12.6412, -12.0553], [-11.8415, -13.1696]]],
+ [[[-16.0378, -19.9641], [-15.4939, -19.0260]]],
+ [[[-18.8254, -23.6185], [-17.7889, -23.2116]]],
+ [[[-25.7024, -29.8722], [-22.9264, -30.0557]]],
+ ],
+ [
+ [[[-19.0264, -17.0396], [-16.9458, -16.3287]]],
+ [[[-20.9671, -19.2132], [-18.5827, -18.0511]]],
+ [[[-22.4642, -19.7389], [-19.4541, -19.4717]]],
+ [[[-21.9226, -18.6297], [-18.9272, -18.8151]]],
+ ],
+ ]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_mask_generation_from_existing_points_and_mask(self):
+ raw_image = prepare_image()
+ input_points = [[[[500, 375]]]]
+ input_labels = [[[1]]]
+ original_inputs = self.processor(
+ images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
+ ).to(torch_device)
+ with torch.no_grad():
+ outputs = self.model(**original_inputs)
+
+ # best mask to use as input for new points
+ mask_input = outputs.pred_masks[:, :, torch.argmax(outputs.iou_scores)]
+
+ new_input_points = [[[[500, 375], [1125, 625]]]]
+ new_input_labels = [[[1, 1]]]
+ inputs = self.processor(
+ input_points=new_input_points,
+ input_labels=new_input_labels,
+ original_sizes=original_inputs["original_sizes"],
+ return_tensors="pt",
+ ).to(torch_device)
+ with torch.no_grad():
+ outputs = self.model(
+ **inputs,
+ input_masks=mask_input,
+ image_embeddings=outputs.image_embeddings,
+ multimask_output=False,
+ )
+
+ self.assertEqual(outputs.iou_scores.shape, (1, 1, 1))
+ self.assertEqual(outputs.pred_masks.shape, (1, 1, 1, 256, 256))
+ scores = outputs.iou_scores.squeeze((0, 1))
+ masks_logits = outputs.pred_masks.squeeze((0, 1))[0, :3, :3]
+ torch.testing.assert_close(scores, torch.tensor([0.9431]).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ masks_logits,
+ torch.tensor([[-4.1968, -4.9034, -6.0680], [-4.4053, -5.1200, -5.8580], [-4.3920, -5.5096, -5.8166]]).to(
+ torch_device
+ ),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ # with negative point
+ new_input_points = [[[[500, 375], [1125, 625]]]]
+ new_input_labels = [[[1, 0]]]
+ inputs = self.processor(
+ input_points=new_input_points,
+ input_labels=new_input_labels,
+ original_sizes=original_inputs["original_sizes"],
+ return_tensors="pt",
+ ).to(torch_device)
+ with torch.no_grad():
+ outputs = self.model(
+ **inputs,
+ input_masks=mask_input,
+ image_embeddings=outputs.image_embeddings,
+ multimask_output=False,
+ )
+ self.assertEqual(outputs.iou_scores.shape, (1, 1, 1))
+ self.assertEqual(outputs.pred_masks.shape, (1, 1, 1, 256, 256))
+ scores = outputs.iou_scores.squeeze((0, 1))
+ masks_logits = outputs.pred_masks.squeeze((0, 1))[0, :3, :3]
+ torch.testing.assert_close(scores, torch.tensor([0.9695]).to(torch_device), atol=1e-4, rtol=1e-4)
+ torch.testing.assert_close(
+ masks_logits,
+ torch.tensor(
+ [[-14.3212, -15.4295, -17.4482], [-13.2246, -15.9468, -17.1341], [-15.1678, -16.4498, -14.7385]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_dummy_pipeline_generation(self):
+ generator = pipeline("mask-generation", model="yonigozlan/EdgeTAM-hf", device=torch_device)
+ raw_image = prepare_image()
+
+ _ = generator(raw_image, points_per_batch=64)
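For readers skimming the test file above, a condensed usage sketch of the image path it exercises (the `yonigozlan/EdgeTAM-hf` checkpoint is the one used by the integration tests; it may move to an official organization later):

```python
# Usage sketch distilled from the integration tests above; not an official example.
import requests
import torch
from PIL import Image

from transformers import EdgeTamModel, Sam2Processor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = EdgeTamModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(device).eval()
processor = Sam2Processor.from_pretrained("yonigozlan/EdgeTAM-hf")

url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# One positive point on the truck; nesting is (image, object, point, xy).
inputs = processor(
    images=image,
    input_points=[[[[500, 375]]]],
    input_labels=[[[1]]],
    return_tensors="pt",
).to(device)

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.iou_scores.shape)  # (1, 1, 3): three candidate masks per prompt
print(outputs.pred_masks.shape)  # (1, 1, 3, 256, 256): low-resolution mask logits
```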
diff --git a/tests/models/edgetam_video/__init__.py b/tests/models/edgetam_video/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/edgetam_video/test_modeling_edgetam_video.py b/tests/models/edgetam_video/test_modeling_edgetam_video.py
new file mode 100644
index 000000000000..a2ad383351d2
--- /dev/null
+++ b/tests/models/edgetam_video/test_modeling_edgetam_video.py
@@ -0,0 +1,507 @@
+# coding=utf-8
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch SAM2 model."""
+
+import gc
+import unittest
+
+import requests
+
+from transformers.testing_utils import (
+ backend_empty_cache,
+ slow,
+ torch_device,
+)
+from transformers.utils import is_torch_available, is_vision_available
+from transformers.video_utils import load_video
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import EdgeTamVideoModel, Sam2VideoProcessor
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+def prepare_image():
+ img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ return raw_image
+
+
+def prepare_groceries_image():
+ img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg"
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ return raw_image
+
+
+def prepare_dog_img():
+ img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
+ return raw_image
+
+
+def prepare_video():
+ video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
+ raw_video, _ = load_video(video_url)
+ return raw_video
+
+
+@slow
+class EdgeTamVideoModelIntegrationTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.video_model = EdgeTamVideoModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(torch.float32)
+ self.processor = Sam2VideoProcessor.from_pretrained("yonigozlan/EdgeTAM-hf")
+ self.video_model.to(torch_device)
+ self.video_model.eval()
+
+ def tearDown(self):
+ super().tearDown()
+ # clean up as much GPU memory occupied by PyTorch as possible
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_inference_mask_generation_video_one_point(self):
+ raw_video = prepare_video()
+ inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+ ann_frame_idx = 0 # the frame index we interact with
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)
+
+ self.processor.add_inputs_to_inference_session(
+ inference_session=inference_session,
+ frame_idx=ann_frame_idx,
+ obj_ids=ann_obj_id,
+ input_points=[[[[210, 350]]]],
+ input_labels=[[[1]]],
+ )
+ outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+ low_res_masks = outputs.pred_masks
+ self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
+ video_res_masks = self.processor.post_process_masks([low_res_masks], [raw_video.shape[-3:-1]], binarize=False)[
+ 0
+ ]
+ self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ video_res_masks[0, 0, :3, :3],
+ torch.tensor(
+ [[-28.3880, -28.3880, -27.9277], [-27.5260, -27.5260, -27.2455], [-25.5902, -25.5902, -25.7136]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ # test propagate in video frames
+ frames = []
+ for sam2_video_output in self.video_model.propagate_in_video_iterator(
+ inference_session=inference_session,
+ max_frame_num_to_track=2,
+ ):
+ video_res_masks = self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ frames.append(video_res_masks)
+ frames = torch.stack(frames, dim=0)
+ self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ frames[:3, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-28.3880, -28.3880], [-27.5260, -27.5260]]]],
+ [[[[-15.3350, -15.3350], [-15.0002, -15.0002]]]],
+ [[[[-14.8729, -14.8729], [-14.6724, -14.6724]]]],
+ ],
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_mask_generation_video_one_point_propagate_in_video_directly(self):
+ raw_video = prepare_video()
+ inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+ ann_frame_idx = 0 # the frame index we interact with
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)
+
+ self.processor.add_inputs_to_inference_session(
+ inference_session=inference_session,
+ frame_idx=ann_frame_idx,
+ obj_ids=ann_obj_id,
+ input_points=[[[[210, 350]]]],
+ input_labels=[[[1]]],
+ )
+ # test propagate in video frames
+ frames = []
+ for sam2_video_output in self.video_model.propagate_in_video_iterator(
+ inference_session=inference_session,
+ start_frame_idx=ann_frame_idx,
+ max_frame_num_to_track=2,
+ ):
+ video_res_masks = self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ frames.append(video_res_masks)
+ frames = torch.stack(frames, dim=0)
+ self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ print(f"VIDEO_TEST2 - ACTUAL frames[:3, :, :, :2, :2]: {frames[:3, :, :, :2, :2]}")
+ torch.testing.assert_close(
+ frames[:3, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-28.3880, -28.3880], [-27.5260, -27.5260]]]],
+ [[[[-15.3350, -15.3350], [-15.0002, -15.0002]]]],
+ [[[[-14.8729, -14.8729], [-14.6724, -14.6724]]]],
+ ]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_mask_generation_video_multi_points(self):
+ raw_video = prepare_video()
+ inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+ ann_frame_idx = 0 # the frame index we interact with
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)
+
+ self.processor.add_inputs_to_inference_session(
+ inference_session=inference_session,
+ frame_idx=ann_frame_idx,
+ obj_ids=ann_obj_id,
+ input_points=[[[[210, 350], [250, 220]]]],
+ input_labels=[[[1, 1]]],
+ )
+ outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+ low_res_masks = outputs.pred_masks
+ video_res_masks = self.processor.post_process_masks(
+ [outputs.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
+ self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ video_res_masks[0, 0, :3, :3],
+ torch.tensor(
+ [[-17.3081, -17.3081, -16.9805], [-16.8430, -16.8430, -16.6766], [-15.7986, -15.7986, -15.9941]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ # test propagate in video frames
+ frames = []
+ for sam2_video_output in self.video_model.propagate_in_video_iterator(
+ inference_session=inference_session,
+ start_frame_idx=ann_frame_idx,
+ max_frame_num_to_track=2,
+ ):
+ video_res_masks = self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ frames.append(video_res_masks)
+ frames = torch.stack(frames, dim=0)
+ self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ # higher tolerance due to errors propagating from frame to frame
+ torch.testing.assert_close(
+ frames[:3, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-17.3081, -17.3081], [-16.8430, -16.8430]]]],
+ [[[[-14.9302, -14.9302], [-14.8802, -14.8802]]]],
+ [[[[-14.4372, -14.4372], [-14.3697, -14.3697]]]],
+ ]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+
+ def test_inference_mask_generation_video_one_bb(self):
+ raw_video = prepare_video()
+ inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+ ann_frame_idx = 0 # the frame index we interact with
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)
+
+ self.processor.add_inputs_to_inference_session(
+ inference_session=inference_session,
+ frame_idx=ann_frame_idx,
+ obj_ids=ann_obj_id,
+ input_boxes=[[[300, 0, 500, 400]]],
+ )
+ outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+ low_res_masks = outputs.pred_masks
+ video_res_masks = self.processor.post_process_masks(
+ [outputs.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
+ self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ video_res_masks[0, 0, :3, :3],
+ torch.tensor(
+ [[-17.3245, -17.3245, -16.9231], [-16.8773, -16.8773, -16.6082], [-15.8731, -15.8731, -15.9011]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ # test propagate in video frames
+ frames = []
+ for sam2_video_output in self.video_model.propagate_in_video_iterator(
+ inference_session=inference_session,
+ start_frame_idx=ann_frame_idx,
+ max_frame_num_to_track=2,
+ ):
+ video_res_masks = self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ frames.append(video_res_masks)
+ frames = torch.stack(frames, dim=0)
+ self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ # higher tolerance due to errors propagating from frame to frame
+ torch.testing.assert_close(
+ frames[:3, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-17.3245, -17.3245], [-16.8773, -16.8773]]]],
+ [[[[-16.2826, -16.2826], [-15.9087, -15.9087]]]],
+ [[[[-15.8716, -15.8716], [-15.3992, -15.3992]]]],
+ ]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+
+ def test_inference_mask_generation_video_one_point_one_bb(self):
+ raw_video = prepare_video()
+ inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+ ann_frame_idx = 0 # the frame index we interact with
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)
+
+ self.processor.add_inputs_to_inference_session(
+ inference_session=inference_session,
+ frame_idx=ann_frame_idx,
+ obj_ids=ann_obj_id,
+ input_boxes=[[[300, 0, 500, 400]]],
+ input_points=[[[[460, 60]]]],
+ input_labels=[[[1]]],
+ )
+ outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+ low_res_masks = outputs.pred_masks
+ video_res_masks = self.processor.post_process_masks(
+ [outputs.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
+ self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ video_res_masks[0, 0, :3, :3],
+ torch.tensor(
+ [[-13.9780, -13.9780, -13.7824], [-13.7642, -13.7642, -13.6000], [-13.2842, -13.2842, -13.1904]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ # test propagate in video frames
+ frames = []
+ for sam2_video_output in self.video_model.propagate_in_video_iterator(
+ inference_session=inference_session,
+ start_frame_idx=ann_frame_idx,
+ max_frame_num_to_track=2,
+ ):
+ video_res_masks = self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ frames.append(video_res_masks)
+ frames = torch.stack(frames, dim=0)
+ self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ # higher tolerance due to errors propagating from frame to frame
+ torch.testing.assert_close(
+ frames[:3, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-13.9780, -13.9780], [-13.7642, -13.7642]]]],
+ [[[[-16.0142, -16.0142], [-15.5600, -15.5600]]]],
+ [[[[-16.7568, -16.7568], [-16.2460, -16.2460]]]],
+ ]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
+
+ def test_inference_mask_generation_video_multi_objects_multi_points(self):
+ raw_video = prepare_video()
+ inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+ ann_frame_idx = 0 # the frame index we interact with
+ ann_obj_ids = [2, 3] # give a unique id to each object we interact with (they can be any integers)
+
+ self.processor.add_inputs_to_inference_session(
+ inference_session=inference_session,
+ frame_idx=ann_frame_idx,
+ obj_ids=ann_obj_ids,
+ input_points=[[[[200, 300], [230, 250], [275, 175]], [[400, 150]]]],
+ input_labels=[[[1, 1, 0], [1]]],
+ )
+ outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+ low_res_masks = outputs.pred_masks
+ video_res_masks = self.processor.post_process_masks(
+ [outputs.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ self.assertEqual(low_res_masks.shape, (2, 1, 256, 256))
+ self.assertEqual(video_res_masks.shape, (2, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ video_res_masks[:, 0, :2, :2], # both objects
+ torch.tensor(
+ [[[-12.6233, -12.6233], [-12.1809, -12.1809]], [[-13.4556, -13.4556], [-12.9549, -12.9549]]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ # test propagate in video frames
+ frames = []
+ for sam2_video_output in self.video_model.propagate_in_video_iterator(
+ inference_session=inference_session,
+ start_frame_idx=ann_frame_idx,
+ max_frame_num_to_track=2,
+ ):
+ video_res_masks = self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ frames.append(video_res_masks)
+ frames = torch.stack(frames, dim=0)
+ self.assertEqual(frames.shape, (3, 2, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ frames[:3, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-12.6233, -12.6233], [-12.1809, -12.1809]]], [[[-13.4556, -13.4556], [-12.9549, -12.9549]]]],
+ [[[[-12.5589, -12.5589], [-12.4450, -12.4450]]], [[[-12.2181, -12.2181], [-12.0188, -12.0188]]]],
+ [[[[-15.3170, -15.3170], [-15.0254, -15.0254]]], [[[-11.4912, -11.4912], [-11.3171, -11.3171]]]],
+ ]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_propagate_video_from_mask_input(self):
+ raw_video = prepare_video()
+ inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
+ ann_frame_idx = 0 # the frame index we interact with
+ ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)
+
+ # get input_mask
+ self.processor.add_inputs_to_inference_session(
+ inference_session=inference_session,
+ frame_idx=ann_frame_idx,
+ obj_ids=ann_obj_id,
+ input_points=[[[[210, 350], [250, 220]]]],
+ input_labels=[[[1, 1]]],
+ )
+ sam2_video_output = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+
+ # set mask as input
+ self.processor.add_inputs_to_inference_session(
+ inference_session=inference_session,
+ frame_idx=ann_frame_idx,
+ obj_ids=ann_obj_id,
+ input_masks=self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0],
+ )
+ sam2_video_output = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
+ low_res_masks = sam2_video_output.pred_masks
+ self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
+ video_res_masks = self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ video_res_masks[0, 0, :3, :3],
+ torch.tensor(
+ [[-10.0000, -10.0000, -10.0000], [-10.0000, -10.0000, -10.0000], [-10.0000, -10.0000, -10.0000]]
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ # test propagate in video frames
+ frames = []
+ for sam2_video_output in self.video_model.propagate_in_video_iterator(
+ inference_session=inference_session,
+ start_frame_idx=ann_frame_idx,
+ max_frame_num_to_track=2,
+ ):
+ video_res_masks = self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
+ )[0]
+ frames.append(video_res_masks)
+ frames = torch.stack(frames, dim=0)
+ self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
+ torch.testing.assert_close(
+ frames[:3, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-10.0000, -10.0000], [-10.0000, -10.0000]]]],
+ [[[[-17.4083, -17.4083], [-17.2256, -17.2256]]]],
+ [[[[-13.8533, -13.8533], [-13.7759, -13.7759]]]],
+ ],
+ ).to(torch_device),
+ atol=1e-4,
+ rtol=1e-4,
+ )
+
+ def test_inference_propagate_on_streamed_video(self):
+ raw_video = prepare_video()
+
+ inference_session = self.processor.init_video_session(inference_device=torch_device)
+ video_res_masks = []
+ max_frame_num_to_track = 3
+ for frame_idx, frame in enumerate(raw_video):
+ if frame_idx >= max_frame_num_to_track:
+ break
+ inputs = self.processor(images=frame, device=torch_device, return_tensors="pt")
+ if frame_idx == 0:
+ self.processor.add_inputs_to_inference_session(
+ inference_session,
+ frame_idx=0,
+ obj_ids=1,
+ input_points=[[[[210, 350], [250, 220]]]],
+ input_labels=[[[1, 1]]],
+ original_size=inputs.original_sizes[0],
+ )
+ sam2_video_output = self.video_model(inference_session=inference_session, frame=inputs.pixel_values[0])
+ video_res_masks.append(
+ self.processor.post_process_masks(
+ [sam2_video_output.pred_masks], inputs.original_sizes, binarize=False
+ )[0]
+ )
+
+ video_res_masks = torch.stack(video_res_masks, dim=0)
+ self.assertEqual(
+ video_res_masks.shape, (max_frame_num_to_track, 1, 1, raw_video.shape[-3], raw_video.shape[-2])
+ )
+ # higher tolerance due to errors propagating from frame to frame
+ print(f"VIDEO_TEST8 - ACTUAL video_res_masks[:3, :, :, :2, :2]: {video_res_masks[:3, :, :, :2, :2]}")
+ torch.testing.assert_close(
+ video_res_masks[:3, :, :, :2, :2],
+ torch.tensor(
+ [
+ [[[[-17.3081, -17.3081], [-16.8430, -16.8430]]]],
+ [[[[-14.9302, -14.9302], [-14.8802, -14.8802]]]],
+ [[[[-14.4372, -14.4372], [-14.3697, -14.3697]]]],
+ ]
+ ).to(torch_device),
+ atol=1e-2,
+ rtol=1e-2,
+ )
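Similarly, a condensed sketch of the video path exercised above, from prompting a single frame to propagating the mask through the clip (same caveat about the checkpoint name):

```python
# Video usage sketch distilled from the integration tests above; not an official example.
import torch

from transformers import EdgeTamVideoModel, Sam2VideoProcessor
from transformers.video_utils import load_video

device = "cuda" if torch.cuda.is_available() else "cpu"
model = EdgeTamVideoModel.from_pretrained("yonigozlan/EdgeTAM-hf").to(device).eval()
processor = Sam2VideoProcessor.from_pretrained("yonigozlan/EdgeTAM-hf")

video, _ = load_video(
    "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
)
session = processor.init_video_session(video=video, inference_device=device)

# Prompt object 1 with a single positive point on frame 0.
processor.add_inputs_to_inference_session(
    inference_session=session,
    frame_idx=0,
    obj_ids=1,
    input_points=[[[[210, 350]]]],
    input_labels=[[[1]]],
)
model(inference_session=session, frame_idx=0)  # segment the prompted frame

# Propagate the prompt through the remaining frames.
for output in model.propagate_in_video_iterator(inference_session=session):
    masks = processor.post_process_masks(
        [output.pred_masks], [video.shape[-3:-1]], binarize=False
    )[0]
```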
diff --git a/tests/models/sam2/test_modeling_sam2.py b/tests/models/sam2/test_modeling_sam2.py
index a19c6a13d220..dcacd3920a7a 100644
--- a/tests/models/sam2/test_modeling_sam2.py
+++ b/tests/models/sam2/test_modeling_sam2.py
@@ -558,7 +558,6 @@ def test_attention_outputs(self):
)
# Override as Sam2Model has different sub-modules
-
def test_sdpa_can_dispatch_composite_models(self):
"""
Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model.
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 0890a1abc4da..2df8d17d6fad 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -140,7 +140,9 @@
"BarkCausalModel", # Building part of bigger (tested) model.
"BarkModel", # Does not have a forward signature - generation tested with integration tests.
"Sam2HieraDetModel", # Building part of bigger (tested) model.
- "Sam2VideoModel", # inherit from Sam2Model (tested).
+ "Sam2VideoModel", # Partly tested in Sam2Model, not regular model.
+ "EdgeTamVisionModel", # Building part of bigger (tested) model.
+ "EdgeTamVideoModel", # Partly tested in EdgeTamModel, not regular model.
"SeamlessM4TTextToUnitModel", # Building part of bigger (tested) model.
"SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model.
"SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model.
@@ -208,6 +210,7 @@
"models/shieldgemma2/test_modeling_shieldgemma2.py",
"models/llama4/test_modeling_llama4.py",
"models/sam2_video/test_modeling_sam2_video.py",
+ "models/edgetam_video/test_modeling_edgetam_video.py",
]
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
@@ -256,6 +259,8 @@
"SamModel",
"Sam2Model",
"Sam2VideoModel",
+ "EdgeTamModel",
+ "EdgeTamVideoModel",
"SamHQModel",
"DPTForDepthEstimation",
"DecisionTransformerGPT2Model",