Skip to content

Commit

Permalink
Adding SAM2 to the Model Zoo! (#4671)
Browse files Browse the repository at this point in the history
* Adding SAM2 into the model zoo

* Updating model tags

* Adding tests

* Cleaning up _forward_pass

* Updating package requirements

* Updating to work on CPU

* Refactoring code

* Removing 0 masks

* Fixing bug with single mask

* Doc changes + package requirements

* Updating manifest

* tweaks

* Adding initialize model._fields

---------

Co-authored-by: Prerna Dhareshwar <prerna@Prernas-MacBook-Pro.local>
Co-authored-by: brimoor <brimoor@umich.edu>
  • Loading branch information
3 people authored and benjaminpkane committed Aug 15, 2024
1 parent 4c56c14 commit 6a41ee6
Show file tree
Hide file tree
Showing 6 changed files with 993 additions and 25 deletions.
26 changes: 25 additions & 1 deletion docs/scripts/make_model_zoo_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@
import fiftyone as fo
import fiftyone.zoo as foz
{% if 'segment-anything' in name and 'video' in name %}
from fiftyone import ViewField as F
{% endif %}
{% if 'imagenet' in name %}
dataset = foz.load_zoo_dataset(
Expand All @@ -98,6 +101,16 @@
max_samples=50,
shuffle=True,
)
{% elif 'segment-anything' in name and 'video' in name %}
dataset = foz.load_zoo_dataset("quickstart-video", max_samples=2)
# Only retain detections in the first frame
(
dataset
.match_frames(F("frame_number") > 1)
.set_field("frames.detections", None)
.save()
)
{% else %}
dataset = foz.load_zoo_dataset(
"coco-2017",
Expand All @@ -108,7 +121,7 @@
)
{% endif %}
{% if 'segment-anything' in tags %}
{% if 'segment-anything' in tags and 'video' not in tags %}
model = foz.load_zoo_model("{{ name }}")
# Segment inside boxes
Expand All @@ -121,6 +134,17 @@
# Full automatic segmentations
dataset.apply_model(model, label_field="auto")
session = fo.launch_app(dataset)
{% elif 'segment-anything' in tags and 'video' in tags %}
model = foz.load_zoo_model("{{ name }}")
# Segment inside boxes and propagate to all frames
dataset.apply_model(
model,
label_field="segmentations",
prompt_field="frames.detections", # can contain Detections or Keypoints
)
session = fo.launch_app(dataset)
{% elif 'dinov2' in name %}
model = foz.load_zoo_model("{{ name }}")
Expand Down
47 changes: 31 additions & 16 deletions fiftyone/utils/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,20 @@ def _get_classes(self, samples, field_name):
return sorted(classes)

def _forward_pass(self, imgs):
if self._curr_prompt_type == "boxes":
return self._forward_pass_boxes(imgs)

if self._curr_prompt_type == "points":
return self._forward_pass_points(imgs)

return self._forward_pass_auto(imgs)
forward_methods = {
"boxes": self._forward_pass_boxes,
"points": self._forward_pass_points,
None: self._forward_pass_auto,
}
return forward_methods.get(
self._curr_prompt_type, self._forward_pass_auto
)(imgs)

def _load_predictor(self):
return sam.SamPredictor(self._model)

def _forward_pass_boxes(self, imgs):
sam_predictor = sam.SamPredictor(self._model)
sam_predictor = self._load_predictor()
self._output_processor = fout.InstanceSegmenterOutputProcessor(
self._curr_classes
)
Expand Down Expand Up @@ -280,7 +284,7 @@ def _forward_pass_boxes(self, imgs):
return outputs

def _forward_pass_points(self, imgs):
sam_predictor = sam.SamPredictor(self._model)
sam_predictor = self._load_predictor()
self._output_processor = fout.InstanceSegmenterOutputProcessor(
self._curr_classes
)
Expand All @@ -305,7 +309,7 @@ def _forward_pass_points(self, imgs):
continue

for kp in keypoints.keypoints:
sam_points, sam_labels = _to_sam_points(kp.points, w, h)
sam_points, sam_labels = _to_sam_points(kp.points, w, h, kp)

multi_mask, mask_scores, _ = sam_predictor.predict(
point_coords=sam_points,
Expand Down Expand Up @@ -341,9 +345,12 @@ def _forward_pass_points(self, imgs):

return outputs

def _forward_pass_auto(self, imgs):
def _load_auto_generator(self):
kwargs = self.config.auto_kwargs or {}
mask_generator = sam.SamAutomaticMaskGenerator(self._model, **kwargs)
return sam.SamAutomaticMaskGenerator(self._model, **kwargs)

def _forward_pass_auto(self, imgs):
mask_generator = self._load_auto_generator()
self._output_processor = None

outputs = []
Expand All @@ -367,10 +374,16 @@ def _to_sam_input(tensor):
return (255 * tensor.cpu().numpy()).astype("uint8").transpose(1, 2, 0)


def _to_sam_points(points, w, h, negative=False):
scaled_points = np.array(points) * np.array([w, h])
labels = np.zeros(len(points)) if negative else np.ones(len(points))
return scaled_points, labels
def _to_sam_points(points, w, h, keypoint):
points = np.array(points)
valid_rows = ~np.isnan(points).any(axis=1)
scaled_points = np.array(points[valid_rows]) * np.array([w, h])
labels = (
np.array(keypoint.sam2_labels)[valid_rows]
if "sam_labels" in keypoint
else np.ones(len(scaled_points))
)
return scaled_points.astype(np.float32), labels.astype(np.uint32)


def _to_sam_box(box, w, h):
Expand All @@ -386,6 +399,8 @@ def _to_sam_box(box, w, h):

def _mask_to_box(mask):
pos_indices = np.where(mask)
if all(arr.size == 0 for arr in pos_indices):
return None
x1 = np.min(pos_indices[1])
x2 = np.max(pos_indices[1])
y1 = np.min(pos_indices[0])
Expand Down
Loading

0 comments on commit 6a41ee6

Please sign in to comment.