30 | 30 | from ...modeling_tf_outputs import TFBaseModelOutput |
31 | 31 | from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs |
32 | 32 | from ...tf_utils import flatten, functional_layernorm |
33 | | -from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging |
| 33 | +from ...utils import ( |
| 34 | + ModelOutput, |
| 35 | + add_start_docstrings, |
| 36 | + add_start_docstrings_to_model_forward, |
| 37 | + logging, |
| 38 | + replace_return_docstrings, |
| 39 | +) |
34 | 40 | from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig |
35 | 41 |
36 | 42 |
@@ -1400,6 +1406,70 @@ class TFSamPreTrainedModel(TFPreTrainedModel): |
1400 | 1406 | """ |
1401 | 1407 |
1402 | 1408 |
| 1409 | +SAM_VISION_INPUTS_DOCSTRING = r""" |
| 1410 | + Args: |
| 1411 | + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): |
| 1412 | + Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for |
| 1413 | + details. |
| 1414 | + output_attentions (`bool`, *optional*): |
| 1415 | + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| 1416 | + tensors for more detail. |
| 1417 | + output_hidden_states (`bool`, *optional*): |
| 1418 | + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| 1419 | + more detail. |
| 1420 | + return_dict (`bool`, *optional*): |
| 1421 | + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| 1422 | +""" |
| 1423 | + |
| 1424 | + |
| 1425 | +@add_start_docstrings( |
| 1426 | + """The vision model from Sam without any head or projection on top.""", |
| 1427 | + SAM_START_DOCSTRING, |
| 1428 | +) |
| 1429 | +class TFSamVisionModel(TFSamPreTrainedModel): |
| 1430 | + config_class = SamVisionConfig |
| 1431 | + main_input_name = "pixel_values" |
| 1432 | + |
| 1433 | + def __init__(self, config: SamVisionConfig, **kwargs): |
| 1434 | + super().__init__(config, **kwargs) |
| 1435 | + self.vision_encoder = TFSamVisionEncoder(config, name="vision_encoder") |
| 1436 | + |
| 1437 | + def build(self, input_shape=None): |
| 1438 | + if self.built: |
| 1439 | + return |
| 1440 | + self.built = True |
| 1441 | + if getattr(self, "vision_encoder", None) is not None: |
| 1442 | + with tf.name_scope(self.vision_encoder.name): |
| 1443 | + self.vision_encoder.build(None) |
| 1444 | + |
| 1445 | + def get_input_embeddings(self): |
| 1446 | + return self.vision_encoder.patch_embed |
| 1447 | + |
| 1448 | + @unpack_inputs |
| 1449 | + @add_start_docstrings_to_model_forward(SAM_VISION_INPUTS_DOCSTRING) |
| 1450 | + @replace_return_docstrings(output_type=TFSamVisionEncoderOutput, config_class=SamVisionConfig) |
| 1451 | + def call( |
| 1452 | + self, |
| 1453 | + pixel_values: TFModelInputType | None = None, |
| 1454 | + output_attentions: bool | None = None, |
| 1455 | + output_hidden_states: bool | None = None, |
| 1456 | + return_dict: bool | None = None, |
| 1457 | + training: bool = False, |
| 1458 | + **kwargs, |
| 1459 | + ) -> TFSamVisionEncoderOutput | Tuple[tf.Tensor]: |
| 1460 | + r""" |
| 1461 | + Returns: |
| 1462 | +
| 1463 | + """ |
| 1464 | + return self.vision_encoder( |
| 1465 | + pixel_values, |
| 1466 | + output_attentions=output_attentions, |
| 1467 | + output_hidden_states=output_hidden_states, |
| 1468 | + return_dict=return_dict, |
| 1469 | + training=training, |
| 1470 | + ) |
| 1471 | + |
| 1472 | + |
1403 | 1473 | @add_start_docstrings( |
1404 | 1474 | "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ", |
1405 | 1475 | " optional 2D location and bounding boxes.", |
@@ -1653,4 +1723,4 @@ def build(self, input_shape=None): |
1653 | 1723 | self.mask_decoder.build(None) |
1654 | 1724 |
1655 | 1725 |
1656 | | -__all__ = ["TFSamModel", "TFSamPreTrainedModel"] |
| 1726 | +__all__ = ["TFSamVisionModel", "TFSamModel", "TFSamPreTrainedModel"] |
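Brief usage sketch (not part of the diff above): with this change, the SAM vision encoder can be called on its own through the new TFSamVisionModel. The snippet assumes the "facebook/sam-vit-base" checkpoint and that its vision-encoder weights load into the standalone model; the 1024x1024 channels-first input shape follows the pixel_values docstring added here.

import tensorflow as tf

from transformers import TFSamVisionModel

# Assumed checkpoint choice; any SAM checkpoint with compatible vision weights should work.
model = TFSamVisionModel.from_pretrained("facebook/sam-vit-base")

# Dummy batch standing in for SamProcessor output: (batch_size, num_channels, height, width).
pixel_values = tf.random.uniform((1, 3, 1024, 1024))

outputs = model(pixel_values=pixel_values, output_hidden_states=True)
print(outputs.last_hidden_state.shape)  # image embeddings produced by the vision encoder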