Adding SAM2 to the Model Zoo! #4671

Merged
merged 13 commits into from
Aug 15, 2024

Conversation

@prernadh prernadh commented Aug 13, 2024

What changes are proposed in this pull request?

Adds Segment Anything 2 (SAM2) to the FiftyOne Model Zoo.

How is this patch tested? If it is not, please explain why.

Tested manually with the following configurations:

  1. Images: prompted with bounding boxes, prompted with keypoints, and with no prompts
  2. Videos: prompted with bounding boxes and prompted with keypoints

Release Notes

Is this a user-facing change that should be mentioned in the release notes?

  • No. You can skip the rest of this section.
  • Yes. Give a description of this change to be included in the release
    notes for FiftyOne users.

Added SAM2 to the FiftyOne Model Zoo with inference support for both images and videos.

What areas of FiftyOne does this PR affect?

  • App: FiftyOne application changes
  • Build: Build and test infrastructure changes
  • Core: Core fiftyone Python library changes
  • Documentation: FiftyOne documentation changes
  • Other

Box prompt for Images

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset(
    "quickstart", max_samples=25, shuffle=True, seed=51
)

model = foz.load_zoo_model("segment-anything-2-hiera-tiny-image-torch")

# Prompt with boxes
dataset.apply_model(
    model,
    label_field="segmentations",
    prompt_field="ground_truth",
)

session = fo.launch_app(dataset)

Keypoint prompt for Images

import fiftyone as fo
import fiftyone.zoo as foz
from fiftyone import ViewField as F

dataset = foz.load_zoo_dataset("quickstart")
dataset = dataset.filter_labels("ground_truth", F("label") == "person")


# Generate some keypoints
model = foz.load_zoo_model("keypoint-rcnn-resnet50-fpn-coco-torch")
dataset.default_skeleton = model.skeleton
dataset.apply_model(model, label_field="gt")

model = foz.load_zoo_model("segment-anything-2-hiera-tiny-image-torch")

# Prompt with keypoints
dataset.apply_model(
    model,
    label_field="segmentations",
    prompt_field="gt_keypoints",
)

session = fo.launch_app(dataset)

Automatic segmentation for Images

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset(
    "quickstart", max_samples=5, shuffle=True, seed=51
)

model = foz.load_zoo_model("segment-anything-2-hiera-tiny-image-torch")

# Automatic segmentation
dataset.apply_model(model, label_field="auto")

session = fo.launch_app(dataset)

Prompting for Videos

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart-video", max_samples=2)

# Only retain detections on the first frame of each video
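for sample in dataset:
    # Frame numbers are 1-based, so this clears every frame after the first
    for frame_idx in sample.frames:
        frame = sample.frames[frame_idx]
        if frame_idx >= 2:
            frame.detections = None

    # Save once per sample, after all of its frames have been updated
    sample.save()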

model = foz.load_zoo_model("segment-anything-2-hiera-tiny-video-torch")

# Prompt with boxes
dataset.apply_model(
    model,
    label_field="segmentations",
    prompt_field="frames.detections", # You can also pass in a keypoint field here
)

session = fo.launch_app(dataset)

Summary by CodeRabbit

  • New Features

    • Introduced advanced image and video segmentation capabilities using the Segment Anything 2 (SAM2) model.
    • Added tools for users to apply segmentation through various prompting methods including bounding boxes and keypoints.
  • Improvements

    • Enhanced the functionality and maintainability of segmentation methods with a simplified control flow.
    • Improved handling for extracting labels in keypoint processing, enhancing robustness.
  • Bug Fixes

    • Refined dataset handling for video and non-video data, ensuring appropriate model application based on dataset characteristics.
  • Chores

    • Updated linting configuration to accommodate OpenCV functionalities and prevent false positive warnings.

coderabbitai bot commented Aug 13, 2024

Walkthrough

The recent updates enhance the FiftyOne framework by incorporating advanced image and video segmentation capabilities through the Segment Anything 2 (SAM2) model. Key additions include new classes for processing images and videos, improved input handling, and expanded testing functionalities. Additionally, refinements to existing components streamline the code's maintainability and flexibility, allowing for easier future enhancements.

Changes

Files | Change Summary
fiftyone/utils/sam.py, fiftyone/utils/sam2.py, tests/intensive/model_zoo_tests.py | Introduced classes for image and video segmentation; improved input handling and testing framework for various prompt types.
pylintrc | Updated to include cv2 for better linting and dynamic member recognition; improved handling of OpenCV functionalities.

Poem

In the meadow where pixels bloom,
A rabbit hops, dispelling gloom.
With SAM2's magic, we now can see,
Segments of joy, wild and free!
Hopping through frames, what a delight,
Image and video, all in sight! 🐰✨



@prernadh prernadh marked this pull request as ready for review August 13, 2024 22:02
@prernadh prernadh requested review from jacobmarks and brimoor August 13, 2024 22:03
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Outside diff range, codebase verification and nitpick comments (5)
fiftyone/utils/sam2.py (5)

150-165: Ensure model is correctly loaded and evaluated.

The _load_model method correctly loads and evaluates the model on the specified device. Consider logging the model's device and precision for debugging purposes.

logger.info(f"Model loaded on device: {self._device}, using half precision: {self.using_half_precision}")

667-676: Consider logging NaN handling in points.

The function handles NaN values in points. Consider logging the number of filtered points for debugging.

logger.debug(f"Filtered {len(points) - len(scaled_points)} NaN points.")
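
The NaN filtering mentioned above can be illustrated with a minimal, library-free sketch. It assumes keypoints are relative `(x, y)` pairs in `[0, 1]` with undetected points encoded as NaN. The helper name `scale_valid_points` is hypothetical and is not the `_to_sam_points` implementation from this PR.

```python
import math


def scale_valid_points(points, width, height):
    """Drop NaN points and scale the rest from relative to pixel coordinates.

    Hypothetical helper for illustration only.
    Returns (scaled_points, num_dropped).
    """
    scaled = []
    for x, y in points:
        if math.isnan(x) or math.isnan(y):
            continue  # skip keypoints that were not detected
        scaled.append((x * width, y * height))
    return scaled, len(points) - len(scaled)


points = [(0.5, 0.25), (float("nan"), float("nan")), (1.0, 1.0)]
scaled, dropped = scale_valid_points(points, width=640, height=480)
print(scaled, dropped)
```

Returning the dropped count alongside the points makes it easy to emit the suggested debug log without recomputing anything.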

690-698: Consider logging when no box is found.

The function correctly returns None when no box is found. Consider logging this case for debugging.

if all(arr.size == 0 for arr in pos_indices):
    logger.debug("No positive indices found in mask.")
    return None
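
As a rough illustration of the empty-mask case, the sketch below computes a bounding box from a mask stored as nested lists and returns `None` when no cell is set. The name `mask_to_box` and the list-of-rows representation are assumptions for this example; the PR's `_mask_to_box` operates on arrays and may differ.

```python
def mask_to_box(mask):
    """Return (x1, y1, x2, y2) enclosing all truthy cells, or None if empty.

    `mask` is a list of rows. Hypothetical stand-in, not the PR's code.
    """
    rows = [r for r, row in enumerate(mask) if any(row)]
    if not rows:
        # Nothing segmented: callers must be prepared for a missing box
        return None
    cols = [c for row in mask for c, value in enumerate(row) if value]
    return (min(cols), min(rows), max(cols), max(rows))


mask = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 0, 1, 0],
]
print(mask_to_box(mask))
print(mask_to_box([[0, 0], [0, 0]]))
```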

700-738: Improve error handling for frame reading.

The function should handle potential errors during frame reading and resizing more gracefully.

try:
    current_frame = video_reader.read()
    if current_frame is None:
        raise ValueError("Failed to read frame from video.")
    resized_frame = cv2.resize(current_frame, (image_size, image_size)) / 255.0
except Exception as e:
    logger.error(f"Error reading or resizing frame: {e}")
    raise
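
A self-contained variant of this idea, runnable without OpenCV, is sketched below. It assumes the reader signals exhaustion by returning `None` from `read()`, as the suggestion above does; `FakeReader` and `read_frames` are hypothetical names, and a real video reader may instead raise its own exception.

```python
class FakeReader:
    """Stand-in for a video reader that returns None once frames run out."""

    def __init__(self, frames):
        self._frames = iter(frames)

    def read(self):
        return next(self._frames, None)


def read_frames(reader, num_frames):
    """Read exactly `num_frames` frames, failing with the frame number."""
    frames = []
    for frame_number in range(1, num_frames + 1):
        frame = reader.read()
        if frame is None:
            raise ValueError(
                f"Failed to read frame {frame_number} of {num_frames}"
            )
        frames.append(frame)
    return frames


print(read_frames(FakeReader(["f1", "f2"]), 2))
```

Including the frame number in the message narrows down where a truncated or corrupt video stops decoding.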

167-180: Ensure safe handling of None values for prompts and classes

The variables _curr_prompt_type, _curr_prompts, and _curr_classes are used in conditions and loops in fiftyone/utils/sam2.py. If these variables are None, the code might raise exceptions or behave unexpectedly. Consider adding checks to handle None values safely.

  • File: fiftyone/utils/sam2.py
    • Ensure conditions and loops using _curr_prompt_type, _curr_prompts, and _curr_classes handle None values appropriately.
Analysis chain

Handle missing samples gracefully.

In predict_all, if samples is None, ensure that downstream methods can handle None values for prompts and classes without errors.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify handling of None samples in predict_all method.

# Test: Search for occurrences of predict_all and ensure samples is handled.
rg --type python -A 5 'predict_all'

Length of output: 66


Script:

#!/bin/bash
# Find the implementation of the _predict_all method to verify handling of None values.
ast-grep --lang python --pattern $'def _predict_all(self, $_):\n    $$$'

Length of output: 6523


Script:

#!/bin/bash
# Search for occurrences of _curr_prompt_type, _curr_prompts, and _curr_classes to verify handling of None values.
rg --type python '_curr_prompt_type|_curr_prompts|_curr_classes'

Length of output: 97


Script:

#!/bin/bash
# Corrected search for occurrences of _curr_prompt_type, _curr_prompts, and _curr_classes to verify handling of None values.
rg --type py '_curr_prompt_type|_curr_prompts|_curr_classes'

Length of output: 2643

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between c0bc6b0 and b7b8a0d.

Files ignored due to path filters (1)
  • fiftyone/zoo/models/manifest-torch.json is excluded by !**/*.json
Files selected for processing (1)
  • fiftyone/utils/sam2.py (1 hunks)
Additional comments not posted (3)
fiftyone/utils/sam2.py (3)

68-71: Class structure is appropriate.

The SegmentAnything2VideoModelConfig class is well-structured for potential future extensions.


664-666: Function is efficient and correct.

The _to_sam_input function efficiently converts tensors for SAM input.


679-687: Function is efficient and correct.

The _to_sam_box function efficiently converts bounding boxes for SAM input.
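
For readers unfamiliar with the conversion, here is a minimal sketch of the coordinate change involved. FiftyOne stores boxes as relative `[x, y, width, height]`, while SAM-style predictors take absolute `[x1, y1, x2, y2]` pixel boxes. The function name `to_abs_xyxy` is hypothetical and this is not the PR's `_to_sam_box`.

```python
def to_abs_xyxy(box, width, height):
    """Convert a relative [x, y, w, h] box to absolute [x1, y1, x2, y2]."""
    x, y, w, h = box
    return [x * width, y * height, (x + w) * width, (y + h) * height]


print(to_abs_xyxy([0.25, 0.5, 0.5, 0.25], width=640, height=480))
```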

Comment on lines 244 to 251
def _forward_pass(self, imgs):
    if self._curr_prompt_type == "boxes":
        return self._forward_pass_boxes(imgs)

    if self._curr_prompt_type == "points":
        return self._forward_pass_points(imgs)

    return self._forward_pass_auto(imgs)

Refactor prompt type handling for clarity.

The _forward_pass method could benefit from a clearer structure by using a dictionary to map prompt types to their respective methods.

forward_methods = {
    "boxes": self._forward_pass_boxes,
    "points": self._forward_pass_points,
    None: self._forward_pass_auto,
}
return forward_methods.get(self._curr_prompt_type, self._forward_pass_auto)(imgs)
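
The dispatch pattern can be tried in isolation with the sketch below. The class `PromptDispatcher` and its stub handlers are hypothetical; only the attribute and method names mirror the snippet above.

```python
class PromptDispatcher:
    """Toy model that routes a batch to a handler based on prompt type."""

    def __init__(self, prompt_type=None):
        self._curr_prompt_type = prompt_type

    def _forward_pass_boxes(self, imgs):
        return ("boxes", len(imgs))

    def _forward_pass_points(self, imgs):
        return ("points", len(imgs))

    def _forward_pass_auto(self, imgs):
        return ("auto", len(imgs))

    def forward_pass(self, imgs):
        handlers = {
            "boxes": self._forward_pass_boxes,
            "points": self._forward_pass_points,
        }
        # Any prompt type without a handler, including None, falls back to auto
        handler = handlers.get(self._curr_prompt_type, self._forward_pass_auto)
        return handler(imgs)


print(PromptDispatcher("boxes").forward_pass([1, 2]))
print(PromptDispatcher().forward_pass([1, 2, 3]))
```

One trade-off worth noting: with a default in `.get()`, a misspelled prompt type silently runs automatic segmentation. Raising on unknown keys would be a stricter alternative.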

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between b7b8a0d and 818c662.

Files selected for processing (1)
  • tests/intensive/model_zoo_tests.py (4 hunks)
Additional context used
Ruff
tests/intensive/model_zoo_tests.py

247-250: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block

Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}

(SIM108)
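
A small sketch of what the SIM108 rewrite looks like in practice. The wrapper function `build_kwargs` is hypothetical; only the variable names come from the lint message.

```python
def build_kwargs(confidence_thresh, pass_confidence_thresh):
    """Build optional kwargs with a conditional expression (SIM108 style)."""
    return (
        {"confidence_thresh": confidence_thresh}
        if pass_confidence_thresh
        else {}
    )


print(build_kwargs(0.5, True))
print(build_kwargs(0.5, False))
```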

Additional comments not posted (6)
tests/intensive/model_zoo_tests.py (6)

78-85: New test function test_sam2_boxes looks good.

The function is correctly set up to test the SAM2 model with box prompts.


88-97: New test function test_sam2_points looks good.

The function is correctly set up to test the SAM2 model with keypoint prompts.


100-105: New test function test_sam2_auto looks good.

The function is correctly set up to test the SAM2 model with automatic segmentation.


107-113: New test function test_sam2_video looks good.

The function is correctly set up to test the SAM2 model with video prompts.


Line range hint 192-220:
Modification to _apply_models is well-integrated.

The new prompt_type parameter is correctly utilized to handle keypoint prompts.

Tools
Ruff

213-214: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)
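
To make the SIM102 suggestion concrete, the sketch below contrasts a nested check with its combined form. The condition names (`prompt_type`, `has_field`) are invented for illustration and are not taken from the test file.

```python
def use_keypoints_nested(prompt_type, has_field):
    """Nested form that SIM102 flags."""
    if prompt_type == "keypoints":
        if has_field:
            return True
    return False


def use_keypoints_combined(prompt_type, has_field):
    """Equivalent single condition joined with `and`."""
    return prompt_type == "keypoints" and has_field


print(use_keypoints_combined("keypoints", True))
```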




238-268: Function _apply_video_models is well-structured.

The function correctly handles video datasets and integrates the necessary logic for testing.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 818c662 and 19095a6.

Files selected for processing (1)
  • fiftyone/utils/sam2.py (1 hunks)
Files skipped from review as they are similar to previous changes (1)
  • fiftyone/utils/sam2.py

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 19095a6 and f6f90ab.

Files ignored due to path filters (1)
  • fiftyone/zoo/models/manifest-torch.json is excluded by !**/*.json
Files selected for processing (2)
  • fiftyone/utils/sam2.py (1 hunks)
  • pylintrc (2 hunks)
Additional context used
Ruff
fiftyone/utils/sam2.py

737-737: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

Additional comments not posted (8)
pylintrc (2)

35-35: Approve addition of cv2 to extension-pkg-whitelist.

This change allows the use of OpenCV functionalities by permitting the loading of the cv2 C extension, which is necessary for image and video processing tasks.


293-293: Approve inclusion of cv2.* in generated-members.

This addition helps prevent false positives by recognizing dynamic members from the cv2 module, improving the linting process.

fiftyone/utils/sam2.py (6)

61-73: Approve SegmentAnything2VideoModelConfig initialization.

The constructor correctly initializes the configuration by calling the superclass constructor. No additional changes are necessary.


692-694: Approve _to_sam_input function.

The function correctly converts a tensor to a SAM input format, ensuring compatibility with the SAM2 model.


696-705: Approve _to_sam_points function.

The function effectively processes keypoints, handling NaN values and scaling them appropriately for SAM input. It also manages labels correctly.


708-716: Approve _to_sam_box function.

The function accurately transforms bounding boxes to SAM format by scaling and adjusting coordinates.


719-727: Approve _mask_to_box function.

The function efficiently computes bounding boxes from masks, handling cases where no positive indices are found.


772-778: Approve _load_video_frames_monkey_patch function.

The function correctly implements a monkey patch for video frame loading, ensuring custom functionality is applied.

Comment on lines +49 to +60
def __init__(self, d):
    d = self.init(d)
    super().__init__(d)

    self.auto_kwargs = self.parse_dict(d, "auto_kwargs", default=None)
    self.points_mask_index = self.parse_int(
        d, "points_mask_index", default=None
    )
    if self.points_mask_index and not 0 <= self.points_mask_index <= 2:
        raise ValueError("mask_index must be 0, 1, or 2")

Validate points_mask_index more robustly.

The validation for points_mask_index should ensure it is an integer before checking its range. Consider using isinstance to check the type.

if self.points_mask_index is not None:
    if not isinstance(self.points_mask_index, int) or not 0 <= self.points_mask_index <= 2:
        raise ValueError("mask_index must be an integer between 0 and 2")
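
A standalone version of this check is sketched below so its edge cases can be exercised directly. The function name `validate_mask_index` is hypothetical. Rejecting `bool` is an extra assumption beyond the suggestion above: `True` and `False` are instances of `int` in Python, so a plain `isinstance(value, int)` check would accept them.

```python
def validate_mask_index(value):
    """Return `value` if it is None or an int in [0, 2], else raise."""
    if value is None:
        return None
    # bool is a subclass of int, so it is excluded explicitly
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError("mask_index must be an integer between 0 and 2")
    if not 0 <= value <= 2:
        raise ValueError("mask_index must be an integer between 0 and 2")
    return value


print(validate_mask_index(0), validate_mask_index(2), validate_mask_index(None))
```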

Comment on lines 249 to 257
def _forward_pass(self, imgs):
    forward_methods = {
        "boxes": self._forward_pass_boxes,
        "points": self._forward_pass_points,
        None: self._forward_pass_auto,
    }
    return forward_methods.get(
        self._curr_prompt_type, self._forward_pass_auto
    )(imgs)

Refactor prompt type handling for clarity.

The _forward_pass method could benefit from a clearer structure by using a dictionary to map prompt types to their respective methods.

forward_methods = {
    "boxes": self._forward_pass_boxes,
    "points": self._forward_pass_points,
    None: self._forward_pass_auto,
}
return forward_methods.get(self._curr_prompt_type, self._forward_pass_auto)(imgs)

Comment on lines +480 to +335
def _get_field(self):
    if "prompt_field" in self.needs_fields:
        prompt_field = self.needs_fields["prompt_field"]
    else:
        prompt_field = next(iter(self.needs_fields.values()), None)

    if not prompt_field.startswith("frames."):
        raise ValueError(
            "'prompt_field' should be a frame field for segment anything 2 video model"
        )

    if prompt_field is None:
        raise AttributeError(
            "Missing required argument 'prompt_field' for segment anything 2 video model"
        )

    prompt_field = prompt_field[len("frames.") :]

Ensure prompt_field is not None before checking its prefix.

The _get_field method should first check if prompt_field is None before verifying its prefix to prevent potential errors.

if prompt_field is None:
    raise AttributeError(
        "Missing required argument 'prompt_field' for segment anything 2 video model"
    )
if not prompt_field.startswith("frames."):
    raise ValueError(
        "'prompt_field' should be a frame field for segment anything 2 video model"
    )
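
The ordering issue is easy to reproduce outside the model class. In the sketch below, `resolve_frame_field` is a hypothetical free function that takes the `needs_fields` mapping as a plain dict; it checks for `None` first, so a missing field produces the intended `AttributeError` rather than failing on `None.startswith`.

```python
def resolve_frame_field(needs_fields):
    """Return the frame-level field name with the 'frames.' prefix removed."""
    if "prompt_field" in needs_fields:
        prompt_field = needs_fields["prompt_field"]
    else:
        prompt_field = next(iter(needs_fields.values()), None)

    # Check for a missing value before calling any string methods on it
    if prompt_field is None:
        raise AttributeError("Missing required argument 'prompt_field'")

    if not prompt_field.startswith("frames."):
        raise ValueError("'prompt_field' should be a frame field")

    return prompt_field[len("frames."):]


print(resolve_frame_field({"prompt_field": "frames.detections"}))
```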

Comment on lines +730 to +570
def load_fiftyone_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    compute_device=torch.device("cuda"),
):
    sample, video_reader = video_path
    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

    num_frames = len(sample.frames)
    try:
        images = torch.zeros(
            num_frames, 3, image_size, image_size, dtype=torch.float32
        )
    except Exception as e:
        raise (e)
    for frame_number in range(num_frames):
        current_frame = video_reader.read()
        resized_frame = (
            cv2.resize(current_frame, (image_size, image_size)) / 255.0
        )
        img = torch.from_numpy(resized_frame).permute(2, 0, 1)
        images[frame_number] = img

    video_width, video_height = (
        current_frame.shape[1],
        current_frame.shape[0],
    )
    if not offload_video_to_cpu:
        images = images.to(compute_device)
        img_mean = img_mean.to(compute_device)
        img_std = img_std.to(compute_device)

    images -= img_mean
    images /= img_std
    return images, video_height, video_width

Avoid function call in argument defaults.

The torch.device call in the compute_device argument default should be moved inside the function to avoid potential issues.

def load_fiftyone_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    compute_device=None,
):
    if compute_device is None:
        compute_device = torch.device("cuda")
    ...
Tools
Ruff

737-737: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)
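
The behavior behind B008 can be demonstrated without torch. In the sketch below, `make_device` is a hypothetical stand-in for `torch.device` that records each call. It shows that a call in an argument default runs once, when the function is defined, whereas the `None` sentinel defers the call until the function is invoked.

```python
calls = []


def make_device(name):
    """Stand-in for a device constructor that records every invocation."""
    calls.append(name)
    return name


def load_eager(device=make_device("cuda")):  # runs at definition time
    return device


def load_lazy(device=None):
    if device is None:
        device = make_device("cuda")  # runs only when actually needed
    return device


calls_after_definition = list(calls)  # already contains one entry
load_eager()
load_eager()
calls_after_eager = list(calls)  # unchanged by the two calls above
load_lazy()
print(calls_after_definition, calls_after_eager, calls)
```

For a cheap, immutable default the difference is mostly stylistic, but the sentinel form also lets callers pass `None` to request the default explicitly.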

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Outside diff range, codebase verification and nitpick comments (1)
fiftyone/utils/sam2.py (1)

50-59: Validate points_mask_index more robustly.

Ensure points_mask_index is an integer before checking its range. Consider using isinstance to check the type.

- if self.points_mask_index and not 0 <= self.points_mask_index <= 2:
+ if self.points_mask_index is not None:
+     if not isinstance(self.points_mask_index, int) or not 0 <= self.points_mask_index <= 2:
+         raise ValueError("mask_index must be an integer between 0 and 2")
Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between f6f90ab and 6bbfe52.

Files selected for processing (2)
  • fiftyone/utils/sam.py (5 hunks)
  • fiftyone/utils/sam2.py (1 hunks)
Additional context used
Ruff
fiftyone/utils/sam2.py

531-531: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

Additional comments not posted (8)
fiftyone/utils/sam.py (6)

224-231: Refactor improves maintainability.

The use of a dictionary to map prompt types to their handling methods simplifies the control flow and enhances maintainability.


233-234: Encapsulation promotes reuse.

The introduction of _load_predictor encapsulates the creation of a SamPredictor instance, promoting code reuse and reducing redundancy.


237-237: Consistent predictor instantiation.

Using _load_predictor for predictor instantiation reduces code duplication and maintains consistency.


287-287: Consistent predictor instantiation and enhanced point handling.

Using _load_predictor maintains consistency, and changes to _to_sam_points enhance its capability to handle different input scenarios.


348-350: Streamlined generator instantiation.

The introduction of _load_auto_generator ensures consistent object creation and initialization.


377-386: Enhanced flexibility in point processing.

The modification to _to_sam_points to accept a keypoint parameter and conditionally extract labels enhances the function's flexibility and robustness.

fiftyone/utils/sam2.py (2)

76-218: Well-structured image model implementation.

The SegmentAnything2ImageModel class is well-organized, with clear methods for loading models and processing prompts, consistent with its intended functionality.


225-443: Comprehensive video model implementation with robust validation.

The SegmentAnything2VideoModel class is comprehensive, with well-defined methods for handling video prompts and processing. The _get_field method includes robust validation checks.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 6bbfe52 and d43976a.

Files selected for processing (1)
  • fiftyone/utils/sam.py (6 hunks)
Additional comments not posted (6)
fiftyone/utils/sam.py (6)

224-231: Great use of dictionary for method mapping!

The refactor using a dictionary to map prompt types to their corresponding methods improves readability and maintainability.


233-234: Encapsulation of predictor creation is a good move!

The introduction of _load_predictor centralizes the creation of the SamPredictor, promoting code reuse and reducing redundancy.


348-350: Centralized generator instantiation is beneficial!

The introduction of _load_auto_generator centralizes the creation of the SamAutomaticMaskGenerator, promoting consistency and modularity.


377-385: Enhanced flexibility in handling keypoints!

The modification to _to_sam_points to accept a keypoint parameter and handle "sam_labels" enhances its flexibility and robustness.


287-287: Consistent handling of empty keypoints!

The handling of cases with no keypoints by returning empty tensors is consistent with the approach in _forward_pass_boxes and is efficient.

Ensure that this behavior is correctly integrated and tested across the codebase.

Verification successful

Consistent and efficient handling of empty keypoints verified!

The _forward_pass_points function in both fiftyone/utils/sam.py and fiftyone/utils/sam2.py handles empty keypoints appropriately. In fiftyone/utils/sam.py, it returns empty tensors, and in fiftyone/utils/sam2.py, it skips processing. This confirms the review comment's observation about efficient handling.

  • fiftyone/utils/sam.py: Returns empty tensors when no keypoints are present.
  • fiftyone/utils/sam2.py: Skips processing when no keypoints are present.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify handling of empty keypoints in `_forward_pass_points`.

# Test: Search for the function usage. Expect: Proper handling of empty keypoints.
rg --type python -A 5 $'_forward_pass_points'

Length of output: 75


Script:

#!/bin/bash
# Description: Verify handling of empty keypoints in `_forward_pass_points`.

# Test: Search for the function usage. Expect: Proper handling of empty keypoints.
rg --type py -A 5 $'_forward_pass_points'

Length of output: 1462


Script:

#!/bin/bash
# Description: Verify handling of empty keypoints in `_forward_pass_points` in both implementations.

# Search for the implementation details of `_forward_pass_points` in `fiftyone/utils/sam.py`
ast-grep --lang python --pattern $'def _forward_pass_points($_):\n    $$$' fiftyone/utils/sam.py

# Search for the implementation details of `_forward_pass_points` in `fiftyone/utils/sam2.py`
ast-grep --lang python --pattern $'def _forward_pass_points($_):\n    $$$' fiftyone/utils/sam2.py

Length of output: 194


Script:

#!/bin/bash
# Description: Verify handling of empty keypoints in `_forward_pass_points` using `rg`.

# Search for the `_forward_pass_points` function definition in `fiftyone/utils/sam.py` with context
rg --type py -A 20 'def _forward_pass_points' fiftyone/utils/sam.py

# Search for the `_forward_pass_points` function definition in `fiftyone/utils/sam2.py` with context
rg --type py -A 20 'def _forward_pass_points' fiftyone/utils/sam2.py

Length of output: 1881


237-237: Efficient handling of empty detections!

The handling of cases with no detections by returning empty tensors is efficient and prevents unnecessary computations.

Ensure that this behavior is correctly integrated and tested across the codebase.

Verification successful

Efficient handling of empty detections confirmed!

The _forward_pass_boxes function in both fiftyone/utils/sam.py and fiftyone/utils/sam2.py efficiently handles cases with no detections by returning empty tensors, thus preventing unnecessary computations. This behavior is correctly integrated into the codebase.

  • fiftyone/utils/sam.py
  • fiftyone/utils/sam2.py
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify handling of empty detections in `_forward_pass_boxes`.

# Test: Search for the function usage. Expect: Proper handling of empty detections.
rg --type python -A 5 $'_forward_pass_boxes'

Length of output: 74


Script:

#!/bin/bash
# Description: Verify handling of empty detections in `_forward_pass_boxes`.

# Test: Search for the function usage. Expect: Proper handling of empty detections.
rg --type py -A 5 $'_forward_pass_boxes'

Length of output: 1758


Script:

#!/bin/bash
# Extract the full implementation of `_forward_pass_boxes` from `fiftyone/utils/sam.py`
ast-grep --lang python --pattern $'def _forward_pass_boxes($_) { $$$ }' fiftyone/utils/sam.py

# Extract the full implementation of `_forward_pass_boxes` from `fiftyone/utils/sam2.py`
ast-grep --lang python --pattern $'def _forward_pass_boxes($_) { $$$ }' fiftyone/utils/sam2.py

Length of output: 1106


Script:

#!/bin/bash
# Correct the pattern and extract the full implementation of `_forward_pass_boxes` from `fiftyone/utils/sam.py`
ast-grep --lang python --pattern 'def _forward_pass_boxes($_):' fiftyone/utils/sam.py

# Correct the pattern and extract the full implementation of `_forward_pass_boxes` from `fiftyone/utils/sam2.py`
ast-grep --lang python --pattern 'def _forward_pass_boxes($_):' fiftyone/utils/sam2.py

Length of output: 176


Script:

#!/bin/bash
# Use `rg` to search for the function definition and include more lines of context to capture the full implementation.
rg --type py -A 20 'def _forward_pass_boxes' fiftyone/utils/sam.py
rg --type py -A 20 'def _forward_pass_boxes' fiftyone/utils/sam2.py

Length of output: 2805

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between d43976a and 89f9d91.

Files selected for processing (1)
  • fiftyone/utils/sam2.py (1 hunks)
Additional context used
Ruff
fiftyone/utils/sam2.py

533-533: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

Additional comments not posted (6)
fiftyone/utils/sam2.py (6)

50-59: Enhance points_mask_index validation.

The validation for points_mask_index should ensure it is an integer before checking its range. Consider using isinstance to check the type.

if self.points_mask_index is not None:
    if not isinstance(self.points_mask_index, int) or not 0 <= self.points_mask_index <= 2:
        raise ValueError("mask_index must be an integer between 0 and 2")
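
A minimal, self-contained sketch of the check suggested above. `validate_points_mask_index` is a hypothetical helper written for illustration and is not part of `fiftyone/utils/sam2.py`. It also rejects `bool`, which a bare `isinstance(..., int)` test would accept.

```python
def validate_points_mask_index(points_mask_index):
    """Hypothetical helper: accept None or an integer in [0, 2]."""
    if points_mask_index is None:
        return None

    # bool is a subclass of int, so exclude it explicitly
    is_int = isinstance(points_mask_index, int) and not isinstance(
        points_mask_index, bool
    )

    # The type check runs first, so non-numeric values never reach the
    # range comparison
    if not is_int or not 0 <= points_mask_index <= 2:
        raise ValueError("points_mask_index must be an integer between 0 and 2")

    return points_mask_index
```

Checking the type before the range means a string or float fails with the same `ValueError` rather than a `TypeError` from the comparison.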

62-74: LGTM!

The SegmentAnything2VideoModelConfig class is well-defined and straightforward.


149-167: LGTM!

The model loading and predictor initialization methods are implemented correctly.


314-330: Ensure prompt_field is not None before checking its prefix.

The _get_field method should first check if prompt_field is None before verifying its prefix to prevent potential errors.

if prompt_field is None:
    raise AttributeError(
        "Missing required argument 'prompt_field' for segment anything 2 video model"
    )
if not prompt_field.startswith("frames."):
    raise ValueError(
        "'prompt_field' should be a frame field for segment anything 2 video model"
    )
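
The ordering above can be sketched as a standalone function. `check_video_prompt_field` is a hypothetical name used only for this example; the exception types and messages are copied from the suggestion.

```python
def check_video_prompt_field(prompt_field):
    """Hypothetical helper mirroring the suggested order of checks."""
    # Test for None first: calling .startswith() on None would fail with
    # a less informative AttributeError
    if prompt_field is None:
        raise AttributeError(
            "Missing required argument 'prompt_field' for segment anything 2 video model"
        )

    if not prompt_field.startswith("frames."):
        raise ValueError(
            "'prompt_field' should be a frame field for segment anything 2 video model"
        )

    return prompt_field
```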

526-534: Avoid function call in argument defaults.

Move the torch.device call inside the function to avoid potential issues.

- compute_device=torch.device("cuda"),
+ compute_device=None,
...
+ if compute_device is None:
+     compute_device = torch.device("cuda")
Tools
Ruff

533-533: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)
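
To illustrate what B008 guards against, here is a small sketch that does not need PyTorch. `make_device` is a hypothetical stand-in for `torch.device` that records each call, and both `load_frames_*` functions are invented for the example.

```python
calls = []


def make_device(name):
    """Hypothetical stand-in for torch.device that records each call."""
    calls.append(name)
    return name


def load_frames_eager(compute_device=make_device("cuda")):
    # The default above was evaluated once, when this def statement ran,
    # even if the function is never called
    return compute_device


def load_frames_lazy(compute_device=None):
    # The default is resolved only when the function is actually called
    # without an explicit device
    if compute_device is None:
        compute_device = make_device("cuda")
    return compute_device
```

With the `None` sentinel, importing the module does not construct a device, and callers that pass their own device never trigger the default at all.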


568-574: LGTM!

The _load_video_frames_monkey_patch function is correctly implemented.

Comment on lines +365 to +374
def _forward_pass(self, video_reader, sample):
    if self._curr_prompt_type == "boxes":
        return self._forward_pass_boxes(video_reader, sample)
    elif self._curr_prompt_type == "points":
        return self._forward_pass_points(video_reader, sample)

Handle unsupported prompt types in _forward_pass.

Consider adding a default case to handle unsupported prompt types gracefully.

else:
    raise ValueError(f"Unsupported prompt type: {self._curr_prompt_type}")
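
A runnable sketch of the method with the suggested `else` branch in place. `PromptRouter` is a hypothetical stand-in for the video model class, and its two `_forward_pass_*` methods return placeholder strings instead of running inference.

```python
class PromptRouter:
    """Hypothetical stand-in for the video model's prompt routing."""

    def __init__(self, prompt_type):
        self._curr_prompt_type = prompt_type

    def _forward_pass_boxes(self, video_reader, sample):
        return "boxes"  # placeholder result

    def _forward_pass_points(self, video_reader, sample):
        return "points"  # placeholder result

    def _forward_pass(self, video_reader, sample):
        if self._curr_prompt_type == "boxes":
            return self._forward_pass_boxes(video_reader, sample)
        elif self._curr_prompt_type == "points":
            return self._forward_pass_points(video_reader, sample)
        else:
            # Without this branch the method would silently return None
            raise ValueError(
                f"Unsupported prompt type: {self._curr_prompt_type}"
            )
```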

Comment on lines +168 to +221
def _forward_pass_boxes(self, imgs):
    sam2_predictor = self._load_predictor()
    self._output_processor = fout.InstanceSegmenterOutputProcessor(
        self._curr_classes
    )
    outputs = []
    for img, detections in zip(imgs, self._curr_prompts):
        ## If no detections, return empty tensors instead of running SAM
        if detections is None or len(detections.detections) == 0:
            h, w = img.shape[1], img.shape[2]
            outputs.append(
                {
                    "boxes": torch.tensor([[]]),
                    "labels": torch.empty([0, 4]),
                    "masks": torch.empty([0, 1, h, w]),
                }
            )
            continue
        inp = fosam._to_sam_input(img)
        sam2_predictor.set_image(inp)
        h, w = img.size(1), img.size(2)

        boxes = [d.bounding_box for d in detections.detections]
        sam_boxes = np.array(
            [fosam._to_sam_box(box, w, h) for box in boxes]
        )
        input_boxes = torch.tensor(sam_boxes, device=sam2_predictor.device)

        labels = torch.tensor(
            [
                self._curr_classes.index(d.label)
                for d in detections.detections
            ],
            device=sam2_predictor.device,
        )

        masks, _, _ = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=sam_boxes[None, :],
            multimask_output=False,
        )
        if masks.ndim == 3:
            masks = np.expand_dims(masks, axis=0)
        outputs.append(
            {
                "boxes": input_boxes,
                "labels": labels,
                "masks": torch.tensor(masks, device=sam2_predictor.device),
            }
        )

    return outputs

Optimize empty tensor handling in _forward_pass_boxes.

The handling of empty detections can be optimized by using torch.empty with the correct device directly.

+ device = sam2_predictor.device
  h, w = img.shape[1], img.shape[2]
  outputs.append(
      {
-         "boxes": torch.tensor([[]]),
-         "labels": torch.empty([0, 4]),
-         "masks": torch.empty([0, 1, h, w]),
+         "boxes": torch.tensor([[]], device=device),
+         "labels": torch.empty([0, 4], device=device),
+         "masks": torch.empty([0, 1, h, w], device=device),
      }
  )
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
- def _forward_pass_boxes(self, imgs):
-     sam2_predictor = self._load_predictor()
-     self._output_processor = fout.InstanceSegmenterOutputProcessor(
-         self._curr_classes
-     )
-     outputs = []
-     for img, detections in zip(imgs, self._curr_prompts):
-         ## If no detections, return empty tensors instead of running SAM
-         if detections is None or len(detections.detections) == 0:
-             h, w = img.shape[1], img.shape[2]
-             outputs.append(
-                 {
-                     "boxes": torch.tensor([[]]),
-                     "labels": torch.empty([0, 4]),
-                     "masks": torch.empty([0, 1, h, w]),
-                 }
-             )
-             continue
-         inp = fosam._to_sam_input(img)
-         sam2_predictor.set_image(inp)
-         h, w = img.size(1), img.size(2)
-         boxes = [d.bounding_box for d in detections.detections]
-         sam_boxes = np.array(
-             [fosam._to_sam_box(box, w, h) for box in boxes]
-         )
-         input_boxes = torch.tensor(sam_boxes, device=sam2_predictor.device)
-         labels = torch.tensor(
-             [
-                 self._curr_classes.index(d.label)
-                 for d in detections.detections
-             ],
-             device=sam2_predictor.device,
-         )
-         masks, _, _ = sam2_predictor.predict(
-             point_coords=None,
-             point_labels=None,
-             box=sam_boxes[None, :],
-             multimask_output=False,
-         )
-         if masks.ndim == 3:
-             masks = np.expand_dims(masks, axis=0)
-         outputs.append(
-             {
-                 "boxes": input_boxes,
-                 "labels": labels,
-                 "masks": torch.tensor(masks, device=sam2_predictor.device),
-             }
-         )
-     return outputs
+ def _forward_pass_boxes(self, imgs):
+     sam2_predictor = self._load_predictor()
+     self._output_processor = fout.InstanceSegmenterOutputProcessor(
+         self._curr_classes
+     )
+     outputs = []
+     for img, detections in zip(imgs, self._curr_prompts):
+         ## If no detections, return empty tensors instead of running SAM
+         if detections is None or len(detections.detections) == 0:
+             device = sam2_predictor.device
+             h, w = img.shape[1], img.shape[2]
+             outputs.append(
+                 {
+                     "boxes": torch.tensor([[]], device=device),
+                     "labels": torch.empty([0, 4], device=device),
+                     "masks": torch.empty([0, 1, h, w], device=device),
+                 }
+             )
+             continue
+         inp = fosam._to_sam_input(img)
+         sam2_predictor.set_image(inp)
+         h, w = img.size(1), img.size(2)
+         boxes = [d.bounding_box for d in detections.detections]
+         sam_boxes = np.array(
+             [fosam._to_sam_box(box, w, h) for box in boxes]
+         )
+         input_boxes = torch.tensor(sam_boxes, device=sam2_predictor.device)
+         labels = torch.tensor(
+             [
+                 self._curr_classes.index(d.label)
+                 for d in detections.detections
+             ],
+             device=sam2_predictor.device,
+         )
+         masks, _, _ = sam2_predictor.predict(
+             point_coords=None,
+             point_labels=None,
+             box=sam_boxes[None, :],
+             multimask_output=False,
+         )
+         if masks.ndim == 3:
+             masks = np.expand_dims(masks, axis=0)
+         outputs.append(
+             {
+                 "boxes": input_boxes,
+                 "labels": labels,
+                 "masks": torch.tensor(masks, device=sam2_predictor.device),
+             }
+         )
+     return outputs

@jacobmarks jacobmarks left a comment

LGTM 👍 Works on my end :)


I like these changes. Unification and improved readability for the win!

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 89f9d91 and d5eb1c3.

Files ignored due to path filters (1)
  • fiftyone/zoo/models/manifest-torch.json is excluded by !**/*.json
Files selected for processing (2)
  • docs/scripts/make_model_zoo_docs.py (3 hunks)
  • fiftyone/utils/sam2.py (1 hunks)
Additional context used
Ruff
fiftyone/utils/sam2.py

532-532: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

Additional comments not posted (8)
docs/scripts/make_model_zoo_docs.py (2)

101-109: LGTM! Handling of video datasets is correct.

The logic for loading and processing video datasets is well-implemented and meets the specified requirements.


Line range hint 120-143: LGTM! Enhanced handling of 'segment-anything' tags.

The changes effectively differentiate between video and non-video datasets, applying models with appropriate parameters.

fiftyone/utils/sam2.py (6)

50-59: Validate points_mask_index more robustly.

Ensure points_mask_index is an integer before checking its range. Consider using isinstance to check the type.


62-69: LGTM! Configuration class for video model is appropriate.

The class serves its purpose as a configuration holder for the SAM2 video model.


168-220: Optimize empty tensor handling in _forward_pass_boxes.

The handling of empty detections can be optimized by using torch.empty with the correct device directly.


313-329: Ensure prompt_field is not None before checking its prefix.

The _get_field method should first check if prompt_field is None before verifying its prefix to prevent potential errors.


364-368: Handle unsupported prompt types in _forward_pass.

Consider adding a default case to handle unsupported prompt types gracefully.


525-533: Avoid function call in argument defaults.

Move the torch.device call inside the function to avoid potential issues.

Tools
Ruff

532-532: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

@brimoor brimoor changed the base branch from develop to release/v0.25.0 August 15, 2024 03:04
@brimoor brimoor changed the base branch from release/v0.25.0 to develop August 15, 2024 03:05
@brimoor brimoor left a comment

LGTM, great work! 🥇

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between d5eb1c3 and 33a844e.

Files ignored due to path filters (1)
  • fiftyone/zoo/models/manifest-torch.json is excluded by !**/*.json
Files selected for processing (3)
  • docs/scripts/make_model_zoo_docs.py (4 hunks)
  • fiftyone/utils/sam2.py (1 hunks)
  • tests/intensive/model_zoo_tests.py (4 hunks)
Files skipped from review as they are similar to previous changes (1)
  • docs/scripts/make_model_zoo_docs.py
Additional context used
Ruff
tests/intensive/model_zoo_tests.py

250-253: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block

Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}

(SIM108)

fiftyone/utils/sam2.py

538-538: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

Additional comments not posted (9)
tests/intensive/model_zoo_tests.py (5)

79-86: LGTM!

The test_sam2_boxes function correctly sets up the test for the SAM2 model with box prompts.


89-98: LGTM!

The test_sam2_points function is correctly configured to test the SAM2 model with keypoints.


101-106: LGTM!

The test_sam2_auto function is correctly set up for automatic segmentation testing with the SAM2 model.


109-115: LGTM!

The test_sam2_video function is correctly configured to test video segmentation with the SAM2 model.


Line range hint 194-223: LGTM! Verify the function usage in the codebase.

The addition of the prompt_type parameter enhances the flexibility of the _apply_models function.

Ensure that the parameter is used correctly throughout the codebase.

Verification successful

Verification Complete: Optional Parameter Usage Confirmed

The prompt_type parameter in the _apply_models function is optional, with a default value of None. This ensures that its absence in existing function calls does not affect functionality. The parameter is correctly implemented to enhance flexibility without impacting current usage.

  • The _apply_models function is defined with prompt_type=None, confirming its optional nature.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the usage of the `_apply_models` function with the new `prompt_type` parameter.

# Test: Search for the function usage. Expect: Occurrences with the `prompt_type` parameter.
rg --type python -A 5 '_apply_models'

Length of output: 68


Script:

#!/bin/bash
# Description: Verify the usage of the `_apply_models` function with the new `prompt_type` parameter.

# Test: Search for the function usage. Expect: Occurrences with the `prompt_type` parameter.
rg --type py -A 5 '_apply_models'

Length of output: 6048


Script:

#!/bin/bash
# Description: Extract the definition of the `_apply_models` function to verify the handling of the `prompt_type` parameter.

# Extract the function definition to check for default values or conditional logic.
ast-grep --lang python --pattern $'def _apply_models($_, $_, $_, $_, $_, $_, prompt_type=$_):\n    $$$' 

Length of output: 103


Script:

#!/bin/bash
# Description: Extract the definition of the `_apply_models` function to verify the handling of the `prompt_type` parameter.

# Simplified pattern to match the function definition with the `prompt_type` parameter.
ast-grep --lang python --pattern 'def _apply_models($$$):'

Length of output: 60


Script:

#!/bin/bash
# Description: Search for the definition of the `_apply_models` function using `rg` to verify the handling of the `prompt_type` parameter.

# Use `rg` to find the function definition.
rg --type py -A 10 'def _apply_models'

Length of output: 1375

Tools
Ruff

215-216: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


250-253: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block

Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}

(SIM108)

fiftyone/utils/sam2.py (4)

63-75: LGTM!

The SegmentAnything2VideoModelConfig class is correctly configured.


77-225: LGTM!

The SegmentAnything2ImageModel class is well-implemented, with efficient model loading and prediction methods.


36-60: LGTM! Verify validation logic.

The SegmentAnything2ImageModelConfig class is well-structured with appropriate validation for points_mask_index.

Ensure the validation logic is robust across the codebase.

Verification successful

Validation Logic for points_mask_index is Consistent Across Files

The points_mask_index is used consistently in both fiftyone/utils/sam2.py and fiftyone/utils/sam.py. The validation logic ensuring the index is between 0 and 2 is applied in both files. This confirms that the validation logic is robust across the codebase.

  • fiftyone/utils/sam2.py: Lines related to points_mask_index include its definition and validation.
  • fiftyone/utils/sam.py: Similar usage and validation logic are present.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the usage of `points_mask_index` validation logic.

# Test: Search for `points_mask_index` usage. Expect: Proper validation logic.
rg --type python 'points_mask_index'

Length of output: 67


Script:

#!/bin/bash
# Description: Verify the usage of `points_mask_index` validation logic.

# Test: Search for `points_mask_index` usage. Expect: Proper validation logic.
rg --type py 'points_mask_index'

Length of output: 787


228-529: LGTM! Verify prompt handling logic.

The SegmentAnything2VideoModel class is well-implemented, with robust prompt handling and prediction methods.

Ensure the prompt handling logic is used correctly throughout the codebase.

Comment on lines +250 to +253
if pass_confidence_thresh:
    kwargs = {"confidence_thresh": confidence_thresh}
else:
    kwargs = {}

Simplify with a ternary operator.

The if-else block can be simplified using a ternary operator for clarity.

-    if pass_confidence_thresh:
-        kwargs = {"confidence_thresh": confidence_thresh}
-    else:
-        kwargs = {}
+    kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if pass_confidence_thresh:
-        kwargs = {"confidence_thresh": confidence_thresh}
-    else:
-        kwargs = {}
+    kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}
Tools
Ruff

250-253: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block

Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}

(SIM108)
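
For reference, the conditional expression behaves the same as the original if-else block. The sketch below wraps it in a hypothetical `build_kwargs` helper (not part of the test module) so it can be run in isolation.

```python
def build_kwargs(pass_confidence_thresh, confidence_thresh):
    """Hypothetical helper isolating the expression from the test code."""
    # Only forward the threshold when the caller asked for it
    return (
        {"confidence_thresh": confidence_thresh}
        if pass_confidence_thresh
        else {}
    )
```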

@brimoor brimoor changed the title from "Adding SAM2 into the model zoo" to "Adding SAM2 to the Model Zoo!" Aug 15, 2024
@prernadh prernadh force-pushed the prerna/sam-2-integration branch from 33a844e to 9cb6c58 Compare August 15, 2024 16:18
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 33a844e and 9cb6c58.

Files ignored due to path filters (1)
  • fiftyone/zoo/models/manifest-torch.json is excluded by !**/*.json
Files selected for processing (5)
  • docs/scripts/make_model_zoo_docs.py (4 hunks)
  • fiftyone/utils/sam.py (6 hunks)
  • fiftyone/utils/sam2.py (1 hunks)
  • pylintrc (2 hunks)
  • tests/intensive/model_zoo_tests.py (4 hunks)
Files skipped from review as they are similar to previous changes (2)
  • docs/scripts/make_model_zoo_docs.py
  • pylintrc
Additional context used
Ruff
tests/intensive/model_zoo_tests.py

250-253: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block

Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}

(SIM108)

fiftyone/utils/sam2.py

538-538: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

Additional comments not posted (11)
tests/intensive/model_zoo_tests.py (5)

79-86: New test function addition looks good.

The test_sam2_boxes function is a well-structured addition, consistent with the existing test functions.


89-98: New test function addition looks good.

The test_sam2_points function is a well-structured addition, consistent with the existing test functions, and correctly uses prompt_type.


101-106: New test function addition looks good.

The test_sam2_auto function is a well-structured addition, consistent with the existing test functions.


109-115: New test function addition looks good.

The test_sam2_video function is a well-structured addition, consistent with the existing test functions, and correctly uses _apply_video_models.


Line range hint 194-223: Enhancement to _apply_models looks good.

The addition of the prompt_type parameter enhances the function's flexibility to handle different model requirements.

However, verify that prompt_type is correctly used throughout the codebase.

Tools
Ruff

215-216: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


250-253: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block

Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}

(SIM108)

fiftyone/utils/sam.py (2)

224-231: Refactor of _forward_pass improves clarity.

Using a dictionary to map prompt types to methods simplifies the control flow and enhances maintainability.
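
The dictionary-based dispatch described above might look roughly like the following. This is a sketch under stated assumptions: the function names, the `None` key for unprompted inference, and the placeholder return values are all invented for illustration, since the actual contents of `fiftyone/utils/sam.py` are not shown in this thread.

```python
def forward_boxes(imgs):
    return ("boxes", len(imgs))  # placeholder result


def forward_points(imgs):
    return ("points", len(imgs))  # placeholder result


def forward_auto(imgs):
    return ("auto", len(imgs))  # placeholder result


# Assumed mapping; None stands for "no prompt" in this sketch
FORWARD_METHODS = {
    "boxes": forward_boxes,
    "points": forward_points,
    None: forward_auto,
}


def forward_pass(prompt_type, imgs):
    """Look up the handler for a prompt type and call it."""
    try:
        method = FORWARD_METHODS[prompt_type]
    except KeyError:
        raise ValueError(f"Unsupported prompt type: {prompt_type}") from None
    return method(imgs)
```

Adding a new prompt type then means adding one entry to the mapping, and unknown types fail with an explicit error instead of falling through.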


377-385: Enhancement to _to_sam_points improves robustness.

The addition of the keypoint parameter and conditional label extraction enhances the function's capability to handle various input scenarios.

fiftyone/utils/sam2.py (4)

36-60: Configuration for SegmentAnything2ImageModelConfig looks good.

The initialization and validation of parameters like auto_kwargs and points_mask_index are correctly handled.


63-75: Configuration for SegmentAnything2VideoModelConfig looks good.

The class is correctly set up for configuring the SAM2 video model.


77-221: Implementation of SegmentAnything2ImageModel looks good.

The class correctly handles model loading and inference for images, with well-structured methods.


228-450: Implementation of SegmentAnything2VideoModel looks good.

The class correctly handles video processing and inference, with well-structured methods.

Comment on lines +531 to +539
def load_fiftyone_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    compute_device=torch.device("cuda"),
):

Move torch.device call inside load_fiftyone_video_frames.

The torch.device call should be moved inside the function to avoid potential issues with argument defaults.

-    compute_device=torch.device("cuda"),
+    compute_device=None,
...
+    if compute_device is None:
+        compute_device = torch.device("cuda")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
- def load_fiftyone_video_frames(
-     video_path,
-     image_size,
-     offload_video_to_cpu,
-     img_mean=(0.485, 0.456, 0.406),
-     img_std=(0.229, 0.224, 0.225),
-     async_loading_frames=False,
-     compute_device=torch.device("cuda"),
- ):
+ def load_fiftyone_video_frames(
+     video_path,
+     image_size,
+     offload_video_to_cpu,
+     img_mean=(0.485, 0.456, 0.406),
+     img_std=(0.229, 0.224, 0.225),
+     async_loading_frames=False,
+     compute_device=None,
+ ):
+     if compute_device is None:
+         compute_device = torch.device("cuda")
Tools
Ruff

538-538: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

@benjaminpkane benjaminpkane merged commit 927a1b9 into develop Aug 15, 2024
13 checks passed
@benjaminpkane benjaminpkane deleted the prerna/sam-2-integration branch August 15, 2024 16:40
benjaminpkane pushed a commit that referenced this pull request Aug 15, 2024
* Adding SAM2 into the model zoo

* Updating model tags

* Adding tests

* Cleaning up _forward_pass

* Updating package requirements

* Updating to work on CPU

* Refactoring code

* Removing 0 masks

* Fixing bug with single mask

* Doc changes + package requirements

* Updating manifest

* tweaks

* Adding initialize model._fields

---------

Co-authored-by: Prerna Dhareshwar <prerna@Prernas-MacBook-Pro.local>
Co-authored-by: brimoor <brimoor@umich.edu>