Skip to content

Commit

Permalink
[docs] Adds segment anything document (#3451)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Sep 3, 2024
1 parent cde2221 commit f204abc
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The following examples are included for training:
The following examples are included for inference:

- [Image classification example](docs/image_classification.md)
- [Segment anything 2 example](docs/segment_anything_2.md)
- [Single-shot object detection example](docs/object_detection.md)
- [Face detection example](docs/face_detection.md)
- [Face recognition example](docs/face_recognition.md)
Expand Down
173 changes: 173 additions & 0 deletions examples/docs/segment_anything_2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Segment anything 2 example

[Mask generation](https://huggingface.co/tasks/mask-generation) is the task of generating masks that
identify a specific object or region of interest in a given image.

In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to
generate mask of a selected object in an image.

The source code can be found
at [SegmentAnything2.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java).

## Setup guide

To configure your development environment, follow [setup](../../docs/development/setup.md).

## Run segment anything 2 example

### Input image file

You can find the image used in this example:

![truck](https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg)

### Build the project and run

Use the following command to run the project:

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.cv.SegmentAnything2
```

Your output should look like the following:

```text
[INFO ] - Number of inter-op threads is 12
[INFO ] - Number of intra-op threads is 6
[INFO ] - Segmentation result image has been saved in: build/output/sam2.png
[INFO ] - [
{"class": "", "probability": 0.92789, "bounds": {"x"=0.000, "y"=0.000, "width"=1800.000, "height"=1200.000}}
]
```

An output image with bounding box will be saved as build/output/sam2.png:

![mask](https://resources.djl.ai/images/sam2_truck_1.png)

## Reference - how to import pytorch model:

The original model can be found:

- [sam2-hiera-large](https://huggingface.co/facebook/sam2-hiera-large)
- [sam2-hiera-tiny](https://huggingface.co/facebook/sam2-hiera-tiny)

The model zoo model was traced with `sam2==0.4.1` and `transformers==4.43.4`

### install dependencies

```bash
pip install sam2 transformers
```

### trace the model

```python
import sys
from typing import Tuple

import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from torch import nn


class Sam2Wrapper(nn.Module):

def __init__(
self,
sam_model: SAM2Base,
) -> None:
super().__init__()
self.model = sam_model

# Spatial dim for backbone feature maps
self._bb_feat_sizes = [
(256, 256),
(128, 128),
(64, 64),
]

def extract_features(
self,
input_image: torch.Tensor,
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
backbone_out = self.model.forward_image(input_image)
_, vision_feats, _, _ = self.model._prepare_backbone_features(
backbone_out)
# Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
if self.model.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

feats = [
feat.permute(1, 2,
0).view(1, -1, *feat_size) for feat, feat_size in zip(
vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]

return feats[-1], feats[0], feats[1]

def forward(
self,
input_image: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
image_embed, feature_1, feature_2 = self.extract_features(input_image)
return self.predict(point_coords, point_labels, image_embed, feature_1,
feature_2)

def predict(
self,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
image_embed: torch.Tensor,
feats_1: torch.Tensor,
feats_2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
concat_points = (point_coords, point_labels)

sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=concat_points,
boxes=None,
masks=None,
)

low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
image_embeddings=image_embed[0].unsqueeze(0),
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=True,
repeat_image=False,
high_res_features=[feats_1, feats_2],
)
return low_res_masks, iou_predictions


def trace_model(model_id: str):
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")

predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)
model = Sam2Wrapper(predictor.model)

input_image = torch.ones(1, 3, 1024, 1024).to(device)
input_point = torch.ones(1, 1, 2).to(device)
input_labels = torch.ones(1, 1, dtype=torch.int32, device=device)

converted = torch.jit.trace_module(
model, {
"extract_features": input_image,
"forward": (input_image, input_point, input_labels)
})
torch.jit.save(converted, f"{model_id[9:]}.pt")


if __name__ == '__main__':
hf_model_id = sys.argv[1] if len(
sys.argv) > 1 else "facebook/sam2-hiera-tiny"
trace_model(hf_model_id)
```
119 changes: 119 additions & 0 deletions examples/docs/trace_sam2_img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import sys
from typing import Tuple

import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from torch import nn


class Sam2Wrapper(nn.Module):

def __init__(
self,
sam_model: SAM2Base,
) -> None:
super().__init__()
self.model = sam_model

# Spatial dim for backbone feature maps
self._bb_feat_sizes = [
(256, 256),
(128, 128),
(64, 64),
]

def extract_features(
self,
input_image: torch.Tensor,
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
backbone_out = self.model.forward_image(input_image)
_, vision_feats, _, _ = self.model._prepare_backbone_features(
backbone_out)
# Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
if self.model.directly_add_no_mem_embed:
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

feats = [
feat.permute(1, 2,
0).view(1, -1, *feat_size) for feat, feat_size in zip(
vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]

return feats[-1], feats[0], feats[1]

def forward(
self,
input_image: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
image_embed, feature_1, feature_2 = self.extract_features(input_image)
return self.predict(point_coords, point_labels, image_embed, feature_1,
feature_2)

def predict(
self,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
image_embed: torch.Tensor,
feats_1: torch.Tensor,
feats_2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
concat_points = (point_coords, point_labels)

sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=concat_points,
boxes=None,
masks=None,
)

low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
image_embeddings=image_embed[0].unsqueeze(0),
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=True,
repeat_image=False,
high_res_features=[feats_1, feats_2],
)
return low_res_masks, iou_predictions


def trace_model(model_id: str):
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")

predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)
model = Sam2Wrapper(predictor.model)

input_image = torch.ones(1, 3, 1024, 1024).to(device)
input_point = torch.ones(1, 1, 2).to(device)
input_labels = torch.ones(1, 1, dtype=torch.int32, device=device)

converted = torch.jit.trace_module(
model, {
"extract_features": input_image,
"forward": (input_image, input_point, input_labels)
})
torch.jit.save(converted, f"{model_id[9:]}.pt")


if __name__ == '__main__':
hf_model_id = sys.argv[1] if len(
sys.argv) > 1 else "facebook/sam2-hiera-tiny"
trace_model(hf_model_id)

0 comments on commit f204abc

Please sign in to comment.