Skip to content

Commit

Permalink
TF port of the Segment Anything Model (SAM) (#22970)
Browse files Browse the repository at this point in the history
* First commit

* Add auto-translation with GPT-4

* make fixup

* Add a functional layernorm for TF

* Add all the auxiliary imports etc.

* Add the extra processor and tests

* rebase to main

* Add all the needed fixes to the GPT code

* make fixup

* Make convolutions channels-last so they run on CPU

* make fixup

* Fix final issues

* Fix other models affected by test change

* Clarify comment on the sparse_prompt_embeddings check

* Refactor functional_layernorm, use shape_list in place of .shape in some places

* Remove deprecated torch-alike code

* Update tests/models/sam/test_modeling_tf_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/sam/test_modeling_tf_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor processor with common methods and separated private methods

* make fixup

* Quietly delete the file that didn't do anything (sorry Sylvain)

* Refactor the processor tests into one file

* make fixup

* Clean up some unnecessary indirection

* Fix TF mask postprocessing

* Add more processor equivalence tests

* Refactor generate_crop_boxes to use framework-neutral np code

* Make the serving output correctly conditional

* Fix error message line length

* Use dict keys rather than indices internally in both TF and PT SAM call/forward

* Return dicts internally in the call/forward methods

* Revert changes to common tests and just override check_pt_tf_outputs

* Revert changes to other model tests

* Clarify comments for functional layernorm

* Add missing transpose from PT code

* Removed unused copied from in PT code

* Remove overrides for tests that don't exist in TF

* Fix transpose and update tests for PT and TF to check pred_masks

* Add training flag

* Update tests to use TF checkpoints

* Update index.mdx

* Add missing cross-test decorator

* Remove optional extra asterisks

* Revert return_dict changes in PT code

* Update src/transformers/models/sam/modeling_tf_sam.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove None return annotations on init methods

* Update tests/models/sam/test_processor_sam.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix input_boxes shapes

* make fixup

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
3 people authored May 19, 2023
1 parent 8aa8513 commit 1c460a5
Show file tree
Hide file tree
Showing 14 changed files with 2,940 additions and 44 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
| RoCBert | | | | | |
| RoFormer | | | | | |
| RWKV | | | | | |
| SAM | | | | | |
| SAM | | | | | |
| SegFormer | | | | | |
| SEW | | | | | |
| SEW-D | | | | | |
Expand Down
6 changes: 6 additions & 0 deletions docs/source/en/model_doc/sam.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,9 @@ Resources:

[[autodoc]] SamModel
- forward


## TFSamModel

[[autodoc]] TFSamModel
- call
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3406,6 +3406,13 @@
"TFRoFormerPreTrainedModel",
]
)
_import_structure["models.sam"].extend(
[
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
)
_import_structure["models.segformer"].extend(
[
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -6657,6 +6664,11 @@
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
from .models.sam import (
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSamModel,
TFSamPreTrainedModel,
)
from .models.segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
("roberta", "TFRobertaModel"),
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("sam", "TFSamModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
Expand Down Expand Up @@ -426,6 +427,11 @@
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "TFSamModel"),
]
)

TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
Expand Down Expand Up @@ -476,6 +482,14 @@
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)

TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)


class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING


class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
Expand Down
27 changes: 26 additions & 1 deletion src/transformers/models/sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)


_import_structure = {
Expand All @@ -39,6 +45,17 @@
"SamModel",
"SamPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_sam"] = [
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
Expand Down Expand Up @@ -66,6 +83,14 @@
else:
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
Expand Down
Loading

0 comments on commit 1c460a5

Please sign in to comment.