Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF port of the Segment Anything Model (SAM) #22970

Merged
merged 49 commits into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
cd8df3a
First commit
Rocketknight1 Apr 22, 2023
eb103df
Add auto-translation with GPT-4
Rocketknight1 Apr 24, 2023
0a611f1
make fixup
Rocketknight1 Apr 24, 2023
5067b60
Add a functional layernorm for TF
Rocketknight1 Apr 24, 2023
75eb390
Add all the auxiliary imports etc.
Rocketknight1 Apr 25, 2023
6b38c83
Add the extra processor and tests
Rocketknight1 Apr 26, 2023
ebc235c
rebase to main
Rocketknight1 Apr 26, 2023
b969a9d
Add all the needed fixes to the GPT code
Rocketknight1 Apr 28, 2023
d3b1392
make fixup
Rocketknight1 Apr 28, 2023
9c3066b
Make convolutions channels-last so they run on CPU
Rocketknight1 May 2, 2023
6fe6674
make fixup
Rocketknight1 May 2, 2023
d6dec9a
Fix final issues
Rocketknight1 May 2, 2023
dba0920
Fix other models affected by test change
Rocketknight1 May 2, 2023
989dd3f
Clarify comment on the sparse_prompt_embeddings check
Rocketknight1 May 3, 2023
f842d43
Refactor functional_layernorm, use shape_list in place of .shape in s…
Rocketknight1 May 3, 2023
d6653e7
Remove deprecated torch-alike code
Rocketknight1 May 3, 2023
e872394
Update tests/models/sam/test_modeling_tf_sam.py
Rocketknight1 May 3, 2023
25197cf
Update tests/models/sam/test_modeling_tf_sam.py
Rocketknight1 May 3, 2023
5aabb16
Refactor processor with common methods and separated private methods
Rocketknight1 May 3, 2023
d5e1fee
make fixup
Rocketknight1 May 3, 2023
4650fcf
Quietly delete the file that didn't do anything (sorry Sylvain)
Rocketknight1 May 3, 2023
b72bfc1
Refactor the processor tests into one file
Rocketknight1 May 3, 2023
fc5136a
make fixup
Rocketknight1 May 3, 2023
2e5b4e5
Clean up some unnecessary indirection
Rocketknight1 May 4, 2023
b1cfcdf
Fix TF mask postprocessing
Rocketknight1 May 4, 2023
f7549ba
Add more processor equivalence tests
Rocketknight1 May 4, 2023
7945a2d
Refactor generate_crop_boxes to use framework-neutral np code
Rocketknight1 May 4, 2023
8305ab4
Make the serving output correctly conditional
Rocketknight1 May 4, 2023
f9de054
Fix error message line length
Rocketknight1 May 4, 2023
63d1b68
Use dict keys rather than indices internally in both TF and PT SAM ca…
Rocketknight1 May 4, 2023
faf5cb0
Return dicts internally in the call/forward methods
Rocketknight1 May 4, 2023
ce669d4
Revert changes to common tests and just override check_pt_tf_outputs
Rocketknight1 May 4, 2023
0ff4dc7
Revert changes to other model tests
Rocketknight1 May 4, 2023
34bca0f
Clarify comments for functional layernorm
Rocketknight1 May 5, 2023
74f3291
Add missing transpose from PT code
Rocketknight1 May 9, 2023
fcfef1f
Removed unused copied from in PT code
Rocketknight1 May 10, 2023
392486d
Remove overrides for tests that don't exist in TF
Rocketknight1 May 10, 2023
b7f9dd4
Fix transpose and update tests for PT and TF to check pred_masks
Rocketknight1 May 11, 2023
4f966c8
Add training flag
Rocketknight1 May 11, 2023
f29b109
Update tests to use TF checkpoints
Rocketknight1 May 11, 2023
792f376
Update index.mdx
Rocketknight1 May 11, 2023
7249cb5
Add missing cross-test decorator
Rocketknight1 May 12, 2023
28dac3e
Remove optional extra asterisks
Rocketknight1 May 16, 2023
bbc6886
Revert return_dict changes in PT code
Rocketknight1 May 16, 2023
79d2b81
Update src/transformers/models/sam/modeling_tf_sam.py
Rocketknight1 May 16, 2023
3d59612
Remove None return annotations on init methods
Rocketknight1 May 16, 2023
ee4057f
Update tests/models/sam/test_processor_sam.py
Rocketknight1 May 16, 2023
3902969
Fix input_boxes shapes
Rocketknight1 May 16, 2023
07813f0
make fixup
Rocketknight1 May 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
| SAM | ❌ | ❌ | ✅ | | ❌ |
| SAM | ❌ | ❌ | ✅ | | ❌ |
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
Expand Down
6 changes: 6 additions & 0 deletions docs/source/en/model_doc/sam.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,9 @@ Resources:

[[autodoc]] SamModel
- forward


## TFSamModel

[[autodoc]] TFSamModel
- call
12 changes: 12 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3405,6 +3405,13 @@
"TFRoFormerPreTrainedModel",
]
)
_import_structure["models.sam"].extend(
[
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
)
_import_structure["models.segformer"].extend(
[
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -6646,6 +6653,11 @@
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
from .models.sam import (
TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSamModel,
TFSamPreTrainedModel,
)
from .models.segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
("roberta", "TFRobertaModel"),
("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
("roformer", "TFRoFormerModel"),
("sam", "TFSamModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
Expand Down Expand Up @@ -426,6 +427,11 @@
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
[
("sam", "TFSamModel"),
]
)

TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
Expand Down Expand Up @@ -476,6 +482,14 @@
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)

TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)


class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING


class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
Expand Down
27 changes: 26 additions & 1 deletion src/transformers/models/sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)


_import_structure = {
Expand All @@ -39,6 +45,17 @@
"SamModel",
"SamPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_sam"] = [
"TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSamModel",
"TFSamPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
Expand Down Expand Up @@ -66,6 +83,14 @@
else:
from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
Expand Down
Loading