
Commit e4a60b4

geetu040 authored and qubvel committed
Create and Expose SamVisionModel as public for better accessibility (huggingface#36493)
* move encoder below
* auto modeling
* write SamVisionTester
* fix vision attention shape
* fix SamVisionTest
* minor changes to SamVisionTest
* Revert "fix vision attention shape"

  This reverts commit d2a4083.

* fix attention output shape in new tests
* remove encoder examples
* run modular on got_ocr2
* code formatting
* fix got_ocr2
* ruff fixes
* code quality
* add sam_vision in auto modeling and auto configuration
* remove composite test
* updated index.md
* add TFSamVisionEncoder to __init__
* fix public TFSamVisionEncoder
* remove outdated todo comment
* set test_torch_exportable

  Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* rename: VisionEncoder -> VisionModel
* bring back original SamVisionEncoder
* rename back: VisionEncoderOutput -> VisionModelOutput
* undo changes in SamModelTester
* reuse SamVisionEncoder in SamVisionModel

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
1 parent a7509b8 commit e4a60b4

File tree

12 files changed (+612, -7 lines)


docs/source/en/model_doc/sam.md

Lines changed: 12 additions & 0 deletions
@@ -149,12 +149,24 @@ alt="drawing" width="900"/>
 [[autodoc]] SamImageProcessor
 
 
+## SamVisionModel
+
+[[autodoc]] SamVisionModel
+    - forward
+
+
 ## SamModel
 
 [[autodoc]] SamModel
     - forward
 
 
+## TFSamVisionModel
+
+[[autodoc]] TFSamVisionModel
+    - call
+
+
 ## TFSamModel
 
 [[autodoc]] TFSamModel
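For context, once these autodoc entries are picked up, the newly public class can be exercised roughly as follows. This is a minimal sketch, not part of the diff: it builds random weights from a default `SamVisionConfig` and feeds a dummy zero image; the attribute and output-field names follow the `SamVisionConfig` and `SamVisionEncoderOutput` definitions touched elsewhere in this commit.

```python
import torch
from transformers import SamVisionConfig, SamVisionModel

# Standalone SAM vision tower with randomly initialized weights.
config = SamVisionConfig()
model = SamVisionModel(config).eval()

# Dummy batch of one image at the resolution the config expects (channels-first).
pixel_values = torch.zeros(1, config.num_channels, config.image_size, config.image_size)

with torch.no_grad():
    outputs = model(pixel_values)

# Image embeddings produced by the wrapped SAM vision encoder.
print(outputs.last_hidden_state.shape)
```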

src/transformers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3611,6 +3611,7 @@
         [
             "SamModel",
             "SamPreTrainedModel",
+            "SamVisionModel",
         ]
     )
     _import_structure["models.seamless_m4t"].extend(
@@ -4779,6 +4780,7 @@
         [
             "TFSamModel",
             "TFSamPreTrainedModel",
+            "TFSamVisionModel",
         ]
     )
     _import_structure["models.segformer"].extend(
@@ -8473,6 +8475,7 @@
         from .models.sam import (
             SamModel,
             SamPreTrainedModel,
+            SamVisionModel,
         )
         from .models.seamless_m4t import (
             SeamlessM4TCodeHifiGan,
@@ -9414,6 +9417,7 @@
         from .models.sam import (
             TFSamModel,
             TFSamPreTrainedModel,
+            TFSamVisionModel,
         )
         from .models.segformer import (
             TFSegformerDecodeHead,
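With the import-structure entries and the `TYPE_CHECKING` imports above, both wrappers become importable from the package root, assuming the matching backend is installed:

```python
# PyTorch backend installed
from transformers import SamVisionModel

# TensorFlow backend installed
from transformers import TFSamVisionModel
```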

src/transformers/models/auto/configuration_auto.py

Lines changed: 3 additions & 0 deletions
@@ -274,6 +274,7 @@
         ("rt_detr_v2", "RTDetrV2Config"),
         ("rwkv", "RwkvConfig"),
         ("sam", "SamConfig"),
+        ("sam_vision_model", "SamVisionConfig"),
         ("seamless_m4t", "SeamlessM4TConfig"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
         ("segformer", "SegformerConfig"),
@@ -632,6 +633,7 @@
         ("rt_detr_v2", "RT-DETRv2"),
         ("rwkv", "RWKV"),
         ("sam", "SAM"),
+        ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4T"),
         ("seamless_m4t_v2", "SeamlessM4Tv2"),
         ("segformer", "SegFormer"),
@@ -775,6 +777,7 @@
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("sam_vision_model", "sam"),
     ]
 )
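Registering `"sam_vision_model"` in the config mapping (and pointing it back at the `sam` module in the last hunk) makes the sub-config reachable through the auto API. A hedged sketch of what this enables; the printed class name is an expectation, not output captured from the PR:

```python
from transformers import AutoConfig

# The new model type resolves to the SAM vision sub-config.
config = AutoConfig.for_model("sam_vision_model")
print(type(config).__name__)  # expected: SamVisionConfig
```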

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
@@ -250,6 +250,7 @@
         ("rt_detr_v2", "RTDetrV2Model"),
         ("rwkv", "RwkvModel"),
         ("sam", "SamModel"),
+        ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4TModel"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
         ("segformer", "SegformerModel"),

src/transformers/models/auto/modeling_tf_auto.py

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@
         ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
         ("roformer", "TFRoFormerModel"),
         ("sam", "TFSamModel"),
+        ("sam_vision_model", "TFSamVisionModel"),
         ("segformer", "TFSegformerModel"),
         ("speech_to_text", "TFSpeech2TextModel"),
         ("swiftformer", "TFSwiftFormerModel"),

src/transformers/models/sam/configuration_sam.py

Lines changed: 19 additions & 1 deletion
@@ -183,9 +183,27 @@ class SamVisionConfig(PretrainedConfig):
         mlp_dim (`int`, *optional*):
             The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
             hidden_size`.
-    """
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     SamVisionConfig,
+    ...     SamVisionModel,
+    ... )
+
+    >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
+    >>> configuration = SamVisionConfig()
+
+    >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
+    >>> model = SamVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
 
     base_config_key = "vision_config"
+    model_type = "sam_vision_model"
 
     def __init__(
         self,
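Beyond the docstring example above, the two class attributes are what tie the sub-config into the rest of the library: `model_type` gives the auto classes a key to dispatch on, while `base_config_key` ties this class to the `vision_config` field of a full `SamConfig`. A short illustrative sketch of that relationship, not part of the diff:

```python
from transformers import SamConfig, SamVisionConfig

full = SamConfig()  # composite config with vision / prompt-encoder / mask-decoder parts

# The vision part is a SamVisionConfig and now advertises its own model type.
print(type(full.vision_config).__name__)  # expected: SamVisionConfig
print(full.vision_config.model_type)      # expected: "sam_vision_model"
print(SamVisionConfig.base_config_key)    # "vision_config", as set above
```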

src/transformers/models/sam/modeling_sam.py

Lines changed: 63 additions & 2 deletions
@@ -27,7 +27,13 @@
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
 
 
@@ -1280,6 +1286,61 @@ def _init_weights(self, module):
 """
 
 
+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The vision model from Sam without any head or projection on top.""",
+    SAM_START_DOCSTRING,
+)
+class SamVisionModel(SamPreTrainedModel):
+    config_class = SamVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SamVisionConfig):
+        super().__init__(config)
+        self.vision_encoder = SamVisionEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_encoder.patch_embed
+
+    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=SamVisionEncoderOutput, config_class=SamVisionConfig)
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SamVisionEncoderOutput]:
+        r"""
+        Returns:
+
+        """
+        return self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
 @add_start_docstrings(
     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
     " optional 2D location and bounding boxes.",
@@ -1522,4 +1583,4 @@ def forward(
         )
 
 
-__all__ = ["SamModel", "SamPreTrainedModel"]
+__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]
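Since `SamVisionModel.forward` simply delegates to the wrapped `SamVisionEncoder`, the optional flags behave exactly as described in `SAM_VISION_INPUTS_DOCSTRING`. A quick smoke-test sketch follows; the small, non-default `image_size` is an assumption made only to keep it cheap, and the printed shapes are illustrative rather than asserted by the PR:

```python
import torch
from transformers import SamVisionConfig, SamVisionModel

# Small image size keeps the dummy forward pass fast; weights are random.
config = SamVisionConfig(image_size=256)
model = SamVisionModel(config).eval()

pixel_values = torch.zeros(1, config.num_channels, config.image_size, config.image_size)
with torch.no_grad():
    outputs = model(pixel_values, output_hidden_states=True, output_attentions=True)

print(outputs.last_hidden_state.shape)  # neck output (image embeddings)
print(len(outputs.hidden_states))       # hidden states collected across the encoder
print(len(outputs.attentions))          # one attention map per attention layer
```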

src/transformers/models/sam/modeling_tf_sam.py

Lines changed: 72 additions & 2 deletions
@@ -30,7 +30,13 @@
 from ...modeling_tf_outputs import TFBaseModelOutput
 from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
 from ...tf_utils import flatten, functional_layernorm
-from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
 
 
@@ -1400,6 +1406,70 @@ class TFSamPreTrainedModel(TFPreTrainedModel):
 """
 
 
+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The vision model from Sam without any head or projection on top.""",
+    SAM_START_DOCSTRING,
+)
+class TFSamVisionModel(TFSamPreTrainedModel):
+    config_class = SamVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SamVisionConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "vision_encoder", None) is not None:
+            with tf.name_scope(self.vision_encoder.name):
+                self.vision_encoder.build(None)
+
+    def get_input_embeddings(self):
+        return self.vision_encoder.patch_embed
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+        **kwargs,
+    ) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]:
+        r"""
+        Returns:
+
+        """
+        return self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+
 @add_start_docstrings(
     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
     " optional 2D location and bounding boxes.",
@@ -1653,4 +1723,4 @@ def build(self, input_shape=None):
                 self.mask_decoder.build(None)
 
 
-__all__ = ["TFSamModel", "TFSamPreTrainedModel"]
+__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"]
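The TF wrapper mirrors the PyTorch one, with Keras-style `build`/`call` plumbing around the same `vision_encoder` delegation. A comparable hedged sketch, again assuming a small non-default `image_size` and the channels-first layout documented above:

```python
import tensorflow as tf
from transformers import SamVisionConfig, TFSamVisionModel

config = SamVisionConfig(image_size=256)
model = TFSamVisionModel(config)  # built lazily on first call, random weights

# Channels-first dummy input matching (batch, channels, height, width).
pixel_values = tf.zeros((1, config.num_channels, config.image_size, config.image_size))
outputs = model(pixel_values)

print(outputs.last_hidden_state.shape)
```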

src/transformers/utils/dummy_pt_objects.py

Lines changed: 7 additions & 0 deletions
@@ -8906,6 +8906,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
+class SamVisionModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class SeamlessM4TCodeHifiGan(metaclass=DummyObject):
     _backends = ["torch"]
src/transformers/utils/dummy_tf_objects.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,6 +2375,13 @@ def __init__(self, *args, **kwargs):
23752375
requires_backends(self, ["tf"])
23762376

23772377

2378+
class TFSamVisionModel(metaclass=DummyObject):
2379+
_backends = ["tf"]
2380+
2381+
def __init__(self, *args, **kwargs):
2382+
requires_backends(self, ["tf"])
2383+
2384+
23782385
class TFSegformerDecodeHead(metaclass=DummyObject):
23792386
_backends = ["tf"]
23802387
