Commit 0710e9b

geetu040 and qubvel authored
Create and Expose SamVisionModel as public for better accessibility (#36493)
* move encoder below
* auto modeling
* write SamVisionTester
* fix vision attention shape
* fix SamVisionTest
* minor changes to SamVisionTest
* Revert "fix vision attention shape"

  This reverts commit d2a4083.

* fix attention output shape in new tests
* remove encoder examples
* run modular on got_ocr2
* code formatting
* fix got_ocr2
* ruff fixes
* code quality
* add sam_vision in auto modeling and auto configuration
* remove composite test
* updated index.md
* add TFSamVisionEncoder to __init__
* fix public TFSamVisionEncoder
* remove outdated todo comment
* set test_torch_exportable

  Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* rename: VisionEncoder -> VisionModel
* bring back original SamVisionEncoder
* rename back: VisionEncoderOutput -> VisionModelOutput
* undo changes in SamModelTester
* reuse SamVisionEncoder in SamVisionModel

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
1 parent f99c279 · commit 0710e9b

12 files changed: 612 additions, 7 deletions

docs/source/en/model_doc/sam.md (12 additions, 0 deletions)

@@ -149,12 +149,24 @@ alt="drawing" width="900"/>
 [[autodoc]] SamImageProcessor


+## SamVisionModel
+
+[[autodoc]] SamVisionModel
+    - forward
+
+
 ## SamModel

 [[autodoc]] SamModel
     - forward


+## TFSamVisionModel
+
+[[autodoc]] TFSamVisionModel
+    - call
+
+
 ## TFSamModel

 [[autodoc]] TFSamModel

src/transformers/__init__.py (4 additions, 0 deletions)

@@ -3589,6 +3589,7 @@
         [
             "SamModel",
             "SamPreTrainedModel",
+            "SamVisionModel",
         ]
     )
     _import_structure["models.seamless_m4t"].extend(

@@ -4757,6 +4758,7 @@
         [
             "TFSamModel",
             "TFSamPreTrainedModel",
+            "TFSamVisionModel",
         ]
     )
     _import_structure["models.segformer"].extend(

@@ -8431,6 +8433,7 @@
        from .models.sam import (
            SamModel,
            SamPreTrainedModel,
+            SamVisionModel,
        )
        from .models.seamless_m4t import (
            SeamlessM4TCodeHifiGan,

@@ -9372,6 +9375,7 @@
        from .models.sam import (
            TFSamModel,
            TFSamPreTrainedModel,
+            TFSamVisionModel,
        )
        from .models.segformer import (
            TFSegformerDecodeHead,
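
With these entries in place, both classes become importable from the package root. A minimal smoke test of the new import surface (a sketch; it assumes a `transformers` build that contains this commit, with PyTorch installed for the first class and TensorFlow for the second):

```python
# Hypothetical smoke test for the new public exports.
from transformers import SamVisionConfig, SamVisionModel  # requires torch
from transformers import TFSamVisionModel                 # requires tensorflow

# A default config yields a randomly initialized vision encoder.
model = SamVisionModel(SamVisionConfig())
print(type(model).__name__)  # SamVisionModel
```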

src/transformers/models/auto/configuration_auto.py (3 additions, 0 deletions)

@@ -273,6 +273,7 @@
         ("rt_detr_v2", "RTDetrV2Config"),
         ("rwkv", "RwkvConfig"),
         ("sam", "SamConfig"),
+        ("sam_vision_model", "SamVisionConfig"),
         ("seamless_m4t", "SeamlessM4TConfig"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
         ("segformer", "SegformerConfig"),

@@ -630,6 +631,7 @@
         ("rt_detr_v2", "RT-DETRv2"),
         ("rwkv", "RWKV"),
         ("sam", "SAM"),
+        ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4T"),
         ("seamless_m4t_v2", "SeamlessM4Tv2"),
         ("segformer", "SegFormer"),

@@ -773,6 +775,7 @@
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("sam_vision_model", "sam"),
     ]
 )
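
Registering `sam_vision_model` in the config and model-name mappings lets the auto classes resolve the standalone vision config; the last hunk points the new key back at the existing `sam` module, so no new model folder is needed. A hedged sketch of what this enables, assuming the usual behavior of `AutoConfig.for_model` for registered model types:

```python
from transformers import AutoConfig, SamVisionConfig

# "sam_vision_model" now resolves to SamVisionConfig.
config = AutoConfig.for_model("sam_vision_model")
assert isinstance(config, SamVisionConfig)
```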

src/transformers/models/auto/modeling_auto.py (1 addition, 0 deletions)

@@ -249,6 +249,7 @@
         ("rt_detr_v2", "RTDetrV2Model"),
         ("rwkv", "RwkvModel"),
         ("sam", "SamModel"),
+        ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4TModel"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
         ("segformer", "SegformerModel"),

src/transformers/models/auto/modeling_tf_auto.py (1 addition, 0 deletions)

@@ -80,6 +80,7 @@
         ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
         ("roformer", "TFRoFormerModel"),
         ("sam", "TFSamModel"),
+        ("sam_vision_model", "TFSamVisionModel"),
         ("segformer", "TFSegformerModel"),
         ("speech_to_text", "TFSpeech2TextModel"),
         ("swiftformer", "TFSwiftFormerModel"),

src/transformers/models/sam/configuration_sam.py (19 additions, 1 deletion)

@@ -183,9 +183,27 @@ class SamVisionConfig(PretrainedConfig):
         mlp_dim (`int`, *optional*):
             The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio *
             hidden_size`.
-    """
+
+    Example:
+
+    ```python
+    >>> from transformers import (
+    ...     SamVisionConfig,
+    ...     SamVisionModel,
+    ... )
+
+    >>> # Initializing a SamVisionConfig with `"facebook/sam-vit-huge"` style configuration
+    >>> configuration = SamVisionConfig()
+
+    >>> # Initializing a SamVisionModel (with random weights) from the `"facebook/sam-vit-huge"` style configuration
+    >>> model = SamVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""

     base_config_key = "vision_config"
+    model_type = "sam_vision_model"

     def __init__(
         self,

src/transformers/models/sam/modeling_sam.py (63 additions, 2 deletions)

@@ -27,7 +27,13 @@
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig


@@ -1280,6 +1286,61 @@ def _init_weights(self, module):
 """


+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The vision model from Sam without any head or projection on top.""",
+    SAM_START_DOCSTRING,
+)
+class SamVisionModel(SamPreTrainedModel):
+    config_class = SamVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SamVisionConfig):
+        super().__init__(config)
+        self.vision_encoder = SamVisionEncoder(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.vision_encoder.patch_embed
+
+    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=SamVisionEncoderOutput, config_class=SamVisionConfig)
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SamVisionEncoderOutput]:
+        r"""
+        Returns:
+
+        """
+        return self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+
 @add_start_docstrings(
     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
     " optional 2D location and bounding boxes.",

@@ -1522,4 +1583,4 @@ def forward(
 )


-__all__ = ["SamModel", "SamPreTrainedModel"]
+__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]

src/transformers/models/sam/modeling_tf_sam.py (72 additions, 2 deletions)

@@ -30,7 +30,13 @@
 from ...modeling_tf_outputs import TFBaseModelOutput
 from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
 from ...tf_utils import flatten, functional_layernorm
-from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig


@@ -1400,6 +1406,70 @@ class TFSamPreTrainedModel(TFPreTrainedModel):
 """


+SAM_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """The vision model from Sam without any head or projection on top.""",
+    SAM_START_DOCSTRING,
+)
+class TFSamVisionModel(TFSamPreTrainedModel):
+    config_class = SamVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: SamVisionConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder")
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "vision_encoder", None) is not None:
+            with tf.name_scope(self.vision_encoder.name):
+                self.vision_encoder.build(None)
+
+    def get_input_embeddings(self):
+        return self.vision_encoder.patch_embed
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig)
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        training: bool = False,
+        **kwargs,
+    ) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]:
+        r"""
+        Returns:
+
+        """
+        return self.vision_encoder(
+            pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+
 @add_start_docstrings(
     "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
     " optional 2D location and bounding boxes.",

@@ -1653,4 +1723,4 @@ def build(self, input_shape=None):
                 self.mask_decoder.build(None)


-__all__ = ["TFSamModel", "TFSamPreTrainedModel"]
+__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"]

src/transformers/utils/dummy_pt_objects.py (7 additions, 0 deletions)

@@ -8836,6 +8836,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])


+class SamVisionModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class SeamlessM4TCodeHifiGan(metaclass=DummyObject):
     _backends = ["torch"]

src/transformers/utils/dummy_tf_objects.py (7 additions, 0 deletions)

@@ -2375,6 +2375,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["tf"])


+class TFSamVisionModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFSegformerDecodeHead(metaclass=DummyObject):
     _backends = ["tf"]
