Adding SAM2 to the Model Zoo! #4671
Walkthrough
The recent updates enhance the FiftyOne framework by incorporating advanced image and video segmentation capabilities through the Segment Anything 2 (SAM2) model. Key additions include new classes for processing images and videos, improved input handling, and expanded testing functionalities. Additionally, refinements to existing components streamline the code's maintainability and flexibility, allowing for easier future enhancements.
Actionable comments posted: 2
Outside diff range, codebase verification and nitpick comments (5)
fiftyone/utils/sam2.py (5)
150-165: Ensure model is correctly loaded and evaluated.
The _load_model method correctly loads and evaluates the model on the specified device. Consider logging the model's device and precision for debugging purposes.
logger.info(f"Model loaded on device: {self._device}, using half precision: {self.using_half_precision}")
667-676: Consider logging NaN handling in points.
The function handles NaN values in points. Consider logging the number of filtered points for debugging.
logger.debug(f"Filtered {len(points) - len(scaled_points)} NaN points.")
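The NaN-logging suggestion can be sketched without the real helper. The function below is a hypothetical pure-Python stand-in (the actual code in fiftyone/utils/sam2.py is not shown on this page and presumably works on arrays); `scale_valid_points` and its parameters are illustrative names only.

```python
import logging
import math

logger = logging.getLogger(__name__)


def scale_valid_points(points, width, height):
    """Scale relative (x, y) points to pixels, skipping any with NaN coordinates."""
    scaled = [
        (x * width, y * height)
        for x, y in points
        if not (math.isnan(x) or math.isnan(y))
    ]
    # Report how many points were dropped, as the review comment suggests
    logger.debug("Filtered %d NaN points.", len(points) - len(scaled))
    return scaled
```

Passing the count as a logging argument rather than through an f-string defers string formatting until the debug level is actually enabled.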
690-698: Consider logging when no box is found.
The function correctly returns None when no box is found. Consider logging this case for debugging.
if all(arr.size == 0 for arr in pos_indices):
    logger.debug("No positive indices found in mask.")
    return None
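As a rough illustration of the empty-mask case, here is a hypothetical pure-Python version that takes a 2D list instead of an array. The name `mask_to_box` and the (x1, y1, x2, y2) return order are assumptions made for this sketch, not the signature used in the PR.

```python
import logging

logger = logging.getLogger(__name__)


def mask_to_box(mask):
    """Return (x1, y1, x2, y2) around truthy cells of a 2D list, or None."""
    rows = [r for r, row in enumerate(mask) if any(row)]
    cols = [c for row in mask for c, value in enumerate(row) if value]
    if not rows:
        # Nothing is set in the mask, so there is no box to report
        logger.debug("No positive indices found in mask.")
        return None
    return (min(cols), min(rows), max(cols), max(rows))
```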
700-738: Improve error handling for frame reading.
The function should handle potential errors during frame reading and resizing more gracefully.
try:
    current_frame = video_reader.read()
    if current_frame is None:
        raise ValueError("Failed to read frame from video.")
    resized_frame = cv2.resize(current_frame, (image_size, image_size)) / 255.0
except Exception as e:
    logger.error(f"Error reading or resizing frame: {e}")
    raise
167-180: Ensure safe handling of None values for prompts and classes
The variables _curr_prompt_type, _curr_prompts, and _curr_classes are used in conditions and loops in fiftyone/utils/sam2.py. If these variables are None, the code might raise exceptions or behave unexpectedly. Consider adding checks to handle None values safely.
- File: fiftyone/utils/sam2.py
- Ensure conditions and loops using _curr_prompt_type, _curr_prompts, and _curr_classes handle None values appropriately.
Analysis chain
Handle missing samples gracefully.
In predict_all, if samples is None, ensure that downstream methods can handle None values for prompts and classes without errors.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify handling of None samples in predict_all method.
# Test: Search for occurrences of predict_all and ensure samples is handled.
rg --type python -A 5 'predict_all'
Length of output: 66
Script:
#!/bin/bash
# Find the implementation of the _predict_all method to verify handling of None values.
ast-grep --lang python --pattern $'def _predict_all(self, $_):\n $$$'
Length of output: 6523
Script:
#!/bin/bash
# Search for occurrences of _curr_prompt_type, _curr_prompts, and _curr_classes to verify handling of None values.
rg --type python '_curr_prompt_type|_curr_prompts|_curr_classes'
Length of output: 97
Script:
#!/bin/bash
# Corrected search for occurrences of _curr_prompt_type, _curr_prompts, and _curr_classes to verify handling of None values.
rg --type py '_curr_prompt_type|_curr_prompts|_curr_classes'
Length of output: 2643
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files ignored due to path filters (1)
fiftyone/zoo/models/manifest-torch.json is excluded by !**/*.json
Files selected for processing (1)
- fiftyone/utils/sam2.py (1 hunks)
Additional comments not posted (3)
fiftyone/utils/sam2.py (3)
68-71: Class structure is appropriate.
The SegmentAnything2VideoModelConfig class is well-structured for potential future extensions.
664-666: Function is efficient and correct.
The _to_sam_input function efficiently converts tensors for SAM input.
679-687: Function is efficient and correct.
The _to_sam_box function efficiently converts bounding boxes for SAM input.
fiftyone/utils/sam2.py
Outdated
def _forward_pass(self, imgs):
    if self._curr_prompt_type == "boxes":
        return self._forward_pass_boxes(imgs)

    if self._curr_prompt_type == "points":
        return self._forward_pass_points(imgs)

    return self._forward_pass_auto(imgs)
Refactor prompt type handling for clarity.
The _forward_pass method could benefit from a clearer structure by using a dictionary to map prompt types to their respective methods.
forward_methods = {
    "boxes": self._forward_pass_boxes,
    "points": self._forward_pass_points,
    None: self._forward_pass_auto,
}
return forward_methods.get(self._curr_prompt_type, self._forward_pass_auto)(imgs)
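A self-contained sketch of the dictionary dispatch may help show its behavior. `PromptDispatcher` is a hypothetical stand-in for the model class, and its handlers return labels instead of running SAM2. Because `dict.get` already takes a default, the explicit `None` key in the suggestion is redundant and is left out here.

```python
class PromptDispatcher:
    """Hypothetical stand-in for the model class; handlers only return labels."""

    def __init__(self, prompt_type=None):
        self._curr_prompt_type = prompt_type

    def _forward_pass_boxes(self, imgs):
        return ("boxes", imgs)

    def _forward_pass_points(self, imgs):
        return ("points", imgs)

    def _forward_pass_auto(self, imgs):
        return ("auto", imgs)

    def _forward_pass(self, imgs):
        forward_methods = {
            "boxes": self._forward_pass_boxes,
            "points": self._forward_pass_points,
        }
        # Any key missing from the mapping, including None, falls back to auto mode
        handler = forward_methods.get(
            self._curr_prompt_type, self._forward_pass_auto
        )
        return handler(imgs)
```

As with the original if-chain, an unrecognized prompt type silently falls through to automatic segmentation rather than raising.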
Actionable comments posted: 1
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files selected for processing (1)
- tests/intensive/model_zoo_tests.py (4 hunks)
Additional context used
Ruff
tests/intensive/model_zoo_tests.py
247-250: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block
Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}
(SIM108)
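For reference, the SIM108 rewrite amounts to the following. `build_kwargs` is a hypothetical wrapper added only so the expression can be run on its own; in the test file the assignment sits inline.

```python
def build_kwargs(confidence_thresh, pass_confidence_thresh):
    # A single conditional expression replaces the four-line if/else block
    return {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}
```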
Additional comments not posted (6)
tests/intensive/model_zoo_tests.py (6)
78-85: New test function test_sam2_boxes looks good.
The function is correctly set up to test the SAM2 model with box prompts.
88-97: New test function test_sam2_points looks good.
The function is correctly set up to test the SAM2 model with keypoint prompts.
100-105: New test function test_sam2_auto looks good.
The function is correctly set up to test the SAM2 model with automatic segmentation.
107-113: New test function test_sam2_video looks good.
The function is correctly set up to test the SAM2 model with video prompts.
Line range hint 192-220: Modification to _apply_models is well-integrated.
The new prompt_type parameter is correctly utilized to handle keypoint prompts.
Tools
Ruff
213-214: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
247-250: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block
Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}
(SIM108)
238-268: Function _apply_video_models is well-structured.
The function correctly handles video datasets and integrates the necessary logic for testing.
Tools
Ruff
247-250: Use ternary operator kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {} instead of if-else-block
Replace if-else-block with kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}
(SIM108)
Actionable comments posted: 0
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files selected for processing (1)
- fiftyone/utils/sam2.py (1 hunks)
Files skipped from review as they are similar to previous changes (1)
- fiftyone/utils/sam2.py
Actionable comments posted: 4
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files ignored due to path filters (1)
fiftyone/zoo/models/manifest-torch.json is excluded by !**/*.json
Files selected for processing (2)
- fiftyone/utils/sam2.py (1 hunks)
- pylintrc (2 hunks)
Additional context used
Ruff
fiftyone/utils/sam2.py
737-737: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
Additional comments not posted (8)
pylintrc (2)
35-35: Approve addition of cv2 to extension-pkg-whitelist.
This change allows the use of OpenCV functionalities by permitting the loading of the cv2 C extension, which is necessary for image and video processing tasks.
293-293: Approve inclusion of cv2.* in generated-members.
This addition helps prevent false positives by recognizing dynamic members from the cv2 module, improving the linting process.
fiftyone/utils/sam2.py (6)
61-73: Approve SegmentAnything2VideoModelConfig initialization.
The constructor correctly initializes the configuration by calling the superclass constructor. No additional changes are necessary.
692-694: Approve _to_sam_input function.
The function correctly converts a tensor to a SAM input format, ensuring compatibility with the SAM2 model.
696-705: Approve _to_sam_points function.
The function effectively processes keypoints, handling NaN values and scaling them appropriately for SAM input. It also manages labels correctly.
708-716: Approve _to_sam_box function.
The function accurately transforms bounding boxes to SAM format by scaling and adjusting coordinates.
719-727: Approve _mask_to_box function.
The function efficiently computes bounding boxes from masks, handling cases where no positive indices are found.
772-778: Approve _load_video_frames_monkey_patch function.
The function correctly implements a monkey patch for video frame loading, ensuring custom functionality is applied.
def __init__(self, d):
    d = self.init(d)
    super().__init__(d)

    self.auto_kwargs = self.parse_dict(d, "auto_kwargs", default=None)
    self.points_mask_index = self.parse_int(
        d, "points_mask_index", default=None
    )
    if self.points_mask_index and not 0 <= self.points_mask_index <= 2:
        raise ValueError("mask_index must be 0, 1, or 2")
Validate points_mask_index more robustly.
The validation for points_mask_index should ensure it is an integer before checking its range. Consider using isinstance to check the type.
if self.points_mask_index is not None:
    if not isinstance(self.points_mask_index, int) or not 0 <= self.points_mask_index <= 2:
        raise ValueError("mask_index must be an integer between 0 and 2")
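The difference between the two checks is easiest to see in isolation. In the original condition, a value of 0 is falsy and skips the range test entirely; that is harmless because 0 is valid, but `is not None` states the intent directly. The sketch below uses a hypothetical free function, `validate_points_mask_index`, and additionally rejects `bool`, which goes beyond what the review asks for.

```python
def validate_points_mask_index(value):
    """Return value if it is None or an int in [0, 2]; raise ValueError otherwise."""
    if value is None:
        return None
    # bool is a subclass of int, so reject it explicitly (an assumption of this sketch)
    if isinstance(value, bool) or not isinstance(value, int) or not 0 <= value <= 2:
        raise ValueError("mask_index must be an integer between 0 and 2")
    return value
```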
fiftyone/utils/sam2.py
Outdated
def _forward_pass(self, imgs):
    forward_methods = {
        "boxes": self._forward_pass_boxes,
        "points": self._forward_pass_points,
        None: self._forward_pass_auto,
    }
    return forward_methods.get(
        self._curr_prompt_type, self._forward_pass_auto
    )(imgs)
Refactor prompt type handling for clarity.
The _forward_pass method could benefit from a clearer structure by using a dictionary to map prompt types to their respective methods.
forward_methods = {
    "boxes": self._forward_pass_boxes,
    "points": self._forward_pass_points,
    None: self._forward_pass_auto,
}
return forward_methods.get(self._curr_prompt_type, self._forward_pass_auto)(imgs)
def _get_field(self):
    if "prompt_field" in self.needs_fields:
        prompt_field = self.needs_fields["prompt_field"]
    else:
        prompt_field = next(iter(self.needs_fields.values()), None)

    if not prompt_field.startswith("frames."):
        raise ValueError(
            "'prompt_field' should be a frame field for segment anything 2 video model"
        )

    if prompt_field is None:
        raise AttributeError(
            "Missing required argument 'prompt_field' for segment anything 2 video model"
        )

    prompt_field = prompt_field[len("frames.") :]
Ensure prompt_field is not None before checking its prefix.
The _get_field method should first check if prompt_field is None before verifying its prefix to prevent potential errors.
if prompt_field is None:
    raise AttributeError(
        "Missing required argument 'prompt_field' for segment anything 2 video model"
    )
if not prompt_field.startswith("frames."):
    raise ValueError(
        "'prompt_field' should be a frame field for segment anything 2 video model"
    )
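To make the ordering issue concrete: with the checks in their original order, a missing field reaches `None.startswith(...)`, which raises a generic AttributeError before the descriptive message is ever used. A minimal sketch with the checks reordered follows; `get_frame_prompt_field` is a hypothetical free function that takes a plain dict in place of `self.needs_fields`, and the messages are shortened.

```python
def get_frame_prompt_field(needs_fields):
    """Resolve the frame-level prompt field name from a needs_fields-style dict."""
    if "prompt_field" in needs_fields:
        prompt_field = needs_fields["prompt_field"]
    else:
        prompt_field = next(iter(needs_fields.values()), None)

    # Check for a missing field first; None has no startswith() method
    if prompt_field is None:
        raise AttributeError("Missing required argument 'prompt_field'")

    if not prompt_field.startswith("frames."):
        raise ValueError("'prompt_field' should be a frame field")

    # Strip the "frames." prefix to get the bare field name
    return prompt_field[len("frames."):]
```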
def load_fiftyone_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    compute_device=torch.device("cuda"),
):
    sample, video_reader = video_path
    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

    num_frames = len(sample.frames)
    try:
        images = torch.zeros(
            num_frames, 3, image_size, image_size, dtype=torch.float32
        )
    except Exception as e:
        raise (e)
    for frame_number in range(num_frames):
        current_frame = video_reader.read()
        resized_frame = (
            cv2.resize(current_frame, (image_size, image_size)) / 255.0
        )
        img = torch.from_numpy(resized_frame).permute(2, 0, 1)
        images[frame_number] = img

    video_width, video_height = (
        current_frame.shape[1],
        current_frame.shape[0],
    )
    if not offload_video_to_cpu:
        images = images.to(compute_device)
        img_mean = img_mean.to(compute_device)
        img_std = img_std.to(compute_device)

    images -= img_mean
    images /= img_std
    return images, video_height, video_width
Avoid function call in argument defaults.
The torch.device call in the compute_device argument default should be moved inside the function to avoid potential issues.
def load_fiftyone_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    compute_device=None,
):
    if compute_device is None:
        compute_device = torch.device("cuda")
    ...
Tools
Ruff
737-737: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
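The reason behind B008 is that Python evaluates default argument expressions once, when the `def` statement runs, not on each call. The sketch below demonstrates this with `make_device`, a hypothetical stand-in for `torch.device("cuda")` that counts its invocations, so it runs without PyTorch installed.

```python
import itertools

_counter = itertools.count()


def make_device():
    """Hypothetical stand-in for torch.device("cuda"); numbers each call."""
    return f"device-{next(_counter)}"


def eager_default(compute_device=make_device()):
    # The default was computed once, when the def statement executed
    return compute_device


def lazy_default(compute_device=None):
    # The default is resolved on every call that omits the argument
    if compute_device is None:
        compute_device = make_device()
    return compute_device
```

Every call to `eager_default()` returns the same object created at definition time, whereas `lazy_default()` builds a fresh value per call.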
Actionable comments posted: 1
Outside diff range, codebase verification and nitpick comments (1)
fiftyone/utils/sam2.py (1)
50-59: Validate points_mask_index more robustly.
Ensure points_mask_index is an integer before checking its range. Consider using isinstance to check the type.
- if self.points_mask_index and not 0 <= self.points_mask_index <= 2:
+ if self.points_mask_index is not None:
+     if not isinstance(self.points_mask_index, int) or not 0 <= self.points_mask_index <= 2:
+         raise ValueError("mask_index must be an integer between 0 and 2")
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files selected for processing (2)
- fiftyone/utils/sam.py (5 hunks)
- fiftyone/utils/sam2.py (1 hunks)
Additional context used
Ruff
fiftyone/utils/sam2.py
531-531: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
Additional comments not posted (8)
fiftyone/utils/sam.py (6)
224-231: Refactor improves maintainability.
The use of a dictionary to map prompt types to their handling methods simplifies the control flow and enhances maintainability.
233-234: Encapsulation promotes reuse.
The introduction of _load_predictor encapsulates the creation of a SamPredictor instance, promoting code reuse and reducing redundancy.
237-237: Consistent predictor instantiation.
Using _load_predictor for predictor instantiation reduces code duplication and maintains consistency.
287-287: Consistent predictor instantiation and enhanced point handling.
Using _load_predictor maintains consistency, and changes to _to_sam_points enhance its capability to handle different input scenarios.
348-350: Streamlined generator instantiation.
The introduction of _load_auto_generator ensures consistent object creation and initialization.
377-386: Enhanced flexibility in point processing.
The modification to _to_sam_points to accept a keypoint parameter and conditionally extract labels enhances the function's flexibility and robustness.
fiftyone/utils/sam2.py (2)
76-218: Well-structured image model implementation.
The SegmentAnything2ImageModel class is well-organized, with clear methods for loading models and processing prompts, consistent with its intended functionality.
225-443: Comprehensive video model implementation with robust validation.
The SegmentAnything2VideoModel class is comprehensive, with well-defined methods for handling video prompts and processing. The _get_field method includes robust validation checks.
Actionable comments posted: 0
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files selected for processing (1)
- fiftyone/utils/sam.py (6 hunks)
Additional comments not posted (6)
fiftyone/utils/sam.py (6)
224-231: Great use of dictionary for method mapping!
The refactor using a dictionary to map prompt types to their corresponding methods improves readability and maintainability.
233-234: Encapsulation of predictor creation is a good move!
The introduction of _load_predictor centralizes the creation of the SamPredictor, promoting code reuse and reducing redundancy.
348-350: Centralized generator instantiation is beneficial!
The introduction of _load_auto_generator centralizes the creation of the SamAutomaticMaskGenerator, promoting consistency and modularity.
377-385: Enhanced flexibility in handling keypoints!
The modification to _to_sam_points to accept a keypoint parameter and handle "sam_labels" enhances its flexibility and robustness.
287-287: Consistent handling of empty keypoints!
The handling of cases with no keypoints by returning empty tensors is consistent with the approach in _forward_pass_boxes and is efficient.
Ensure that this behavior is correctly integrated and tested across the codebase.
Verification successful
Consistent and efficient handling of empty keypoints verified!
The _forward_pass_points function in both fiftyone/utils/sam.py and fiftyone/utils/sam2.py handles empty keypoints appropriately. In fiftyone/utils/sam.py, it returns empty tensors, and in fiftyone/utils/sam2.py, it skips processing. This confirms the review comment's observation about efficient handling.
- fiftyone/utils/sam.py: Returns empty tensors when no keypoints are present.
- fiftyone/utils/sam2.py: Skips processing when no keypoints are present.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify handling of empty keypoints in `_forward_pass_points`.
# Test: Search for the function usage. Expect: Proper handling of empty keypoints.
rg --type python -A 5 $'_forward_pass_points'
Length of output: 75
Script:
#!/bin/bash
# Description: Verify handling of empty keypoints in `_forward_pass_points`.
# Test: Search for the function usage. Expect: Proper handling of empty keypoints.
rg --type py -A 5 $'_forward_pass_points'
Length of output: 1462
Script:
#!/bin/bash
# Description: Verify handling of empty keypoints in `_forward_pass_points` in both implementations.
# Search for the implementation details of `_forward_pass_points` in `fiftyone/utils/sam.py`
ast-grep --lang python --pattern $'def _forward_pass_points($_):\n $$$' fiftyone/utils/sam.py
# Search for the implementation details of `_forward_pass_points` in `fiftyone/utils/sam2.py`
ast-grep --lang python --pattern $'def _forward_pass_points($_):\n $$$' fiftyone/utils/sam2.py
Length of output: 194
Script:
#!/bin/bash
# Description: Verify handling of empty keypoints in `_forward_pass_points` using `rg`.
# Search for the `_forward_pass_points` function definition in `fiftyone/utils/sam.py` with context
rg --type py -A 20 'def _forward_pass_points' fiftyone/utils/sam.py
# Search for the `_forward_pass_points` function definition in `fiftyone/utils/sam2.py` with context
rg --type py -A 20 'def _forward_pass_points' fiftyone/utils/sam2.py
Length of output: 1881
237-237: Efficient handling of empty detections!
The handling of cases with no detections by returning empty tensors is efficient and prevents unnecessary computations.
Ensure that this behavior is correctly integrated and tested across the codebase.
Verification successful
Efficient handling of empty detections confirmed!
The _forward_pass_boxes function in both fiftyone/utils/sam.py and fiftyone/utils/sam2.py efficiently handles cases with no detections by returning empty tensors, thus preventing unnecessary computations. This behavior is correctly integrated into the codebase.
- fiftyone/utils/sam.py
- fiftyone/utils/sam2.py
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify handling of empty detections in `_forward_pass_boxes`.
# Test: Search for the function usage. Expect: Proper handling of empty detections.
rg --type python -A 5 $'_forward_pass_boxes'
Length of output: 74
Script:
#!/bin/bash
# Description: Verify handling of empty detections in `_forward_pass_boxes`.
# Test: Search for the function usage. Expect: Proper handling of empty detections.
rg --type py -A 5 $'_forward_pass_boxes'
Length of output: 1758
Script:
#!/bin/bash
# Extract the full implementation of `_forward_pass_boxes` from `fiftyone/utils/sam.py`
ast-grep --lang python --pattern $'def _forward_pass_boxes($_) { $$$ }' fiftyone/utils/sam.py
# Extract the full implementation of `_forward_pass_boxes` from `fiftyone/utils/sam2.py`
ast-grep --lang python --pattern $'def _forward_pass_boxes($_) { $$$ }' fiftyone/utils/sam2.py
Length of output: 1106
Script:
#!/bin/bash
# Correct the pattern and extract the full implementation of `_forward_pass_boxes` from `fiftyone/utils/sam.py`
ast-grep --lang python --pattern 'def _forward_pass_boxes($_):' fiftyone/utils/sam.py
# Correct the pattern and extract the full implementation of `_forward_pass_boxes` from `fiftyone/utils/sam2.py`
ast-grep --lang python --pattern 'def _forward_pass_boxes($_):' fiftyone/utils/sam2.py
Length of output: 176
Script:
#!/bin/bash
# Use `rg` to search for the function definition and include more lines of context to capture the full implementation.
rg --type py -A 20 'def _forward_pass_boxes' fiftyone/utils/sam.py
rg --type py -A 20 'def _forward_pass_boxes' fiftyone/utils/sam2.py
Length of output: 2805
Actionable comments posted: 2
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files selected for processing (1)
- fiftyone/utils/sam2.py (1 hunks)
Additional context used
Ruff
fiftyone/utils/sam2.py
533-533: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
Additional comments not posted (6)
fiftyone/utils/sam2.py (6)
50-59: Enhance points_mask_index validation.
The validation for points_mask_index should ensure it is an integer before checking its range. Consider using isinstance to check the type.
if self.points_mask_index is not None:
    if not isinstance(self.points_mask_index, int) or not 0 <= self.points_mask_index <= 2:
        raise ValueError("mask_index must be an integer between 0 and 2")
62-74: LGTM!
The SegmentAnything2VideoModelConfig class is well-defined and straightforward.
149-167: LGTM!
The model loading and predictor initialization methods are implemented correctly.
314-330: Ensure prompt_field is not None before checking its prefix.
The _get_field method should first check if prompt_field is None before verifying its prefix to prevent potential errors.
if prompt_field is None:
    raise AttributeError(
        "Missing required argument 'prompt_field' for segment anything 2 video model"
    )
if not prompt_field.startswith("frames."):
    raise ValueError(
        "'prompt_field' should be a frame field for segment anything 2 video model"
    )
526-534: Avoid function call in argument defaults.
Move the torch.device call inside the function to avoid potential issues.
- compute_device=torch.device("cuda"),
+ compute_device=None,
...
+ if compute_device is None:
+     compute_device = torch.device("cuda")
Tools
Ruff
533-533: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
568-574: LGTM!
The _load_video_frames_monkey_patch function is correctly implemented.
def _forward_pass(self, video_reader, sample):
    if self._curr_prompt_type == "boxes":
        return self._forward_pass_boxes(video_reader, sample)
    elif self._curr_prompt_type == "points":
        return self._forward_pass_points(video_reader, sample)
Handle unsupported prompt types in _forward_pass.
Consider adding a default case to handle unsupported prompt types gracefully.
else:
    raise ValueError(f"Unsupported prompt type: {self._curr_prompt_type}")
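Without an `else` branch, the method quoted above returns `None` implicitly for any prompt type other than "boxes" or "points", and the failure surfaces later in whatever code consumes the result. One way to fail early is sketched below; `dispatch_video_prompt` and the `handlers` mapping are hypothetical names used only for illustration.

```python
def dispatch_video_prompt(prompt_type, handlers):
    """Look up the handler for prompt_type, raising on unsupported values."""
    try:
        return handlers[prompt_type]
    except KeyError:
        # Surface the bad value immediately instead of returning None
        raise ValueError(f"Unsupported prompt type: {prompt_type}") from None
```

For instance, `dispatch_video_prompt("boxes", {"boxes": run_boxes})` would return `run_boxes`, while passing `None` or a misspelled type raises right away.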
def _forward_pass_boxes(self, imgs):
    sam2_predictor = self._load_predictor()
    self._output_processor = fout.InstanceSegmenterOutputProcessor(
        self._curr_classes
    )
    outputs = []
    for img, detections in zip(imgs, self._curr_prompts):
        ## If no detections, return empty tensors instead of running SAM
        if detections is None or len(detections.detections) == 0:
            h, w = img.shape[1], img.shape[2]
            outputs.append(
                {
                    "boxes": torch.tensor([[]]),
                    "labels": torch.empty([0, 4]),
                    "masks": torch.empty([0, 1, h, w]),
                }
            )
            continue
        inp = fosam._to_sam_input(img)
        sam2_predictor.set_image(inp)
        h, w = img.size(1), img.size(2)

        boxes = [d.bounding_box for d in detections.detections]
        sam_boxes = np.array(
            [fosam._to_sam_box(box, w, h) for box in boxes]
        )
        input_boxes = torch.tensor(sam_boxes, device=sam2_predictor.device)

        labels = torch.tensor(
            [
                self._curr_classes.index(d.label)
                for d in detections.detections
            ],
            device=sam2_predictor.device,
        )

        masks, _, _ = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=sam_boxes[None, :],
            multimask_output=False,
        )
        if masks.ndim == 3:
            masks = np.expand_dims(masks, axis=0)
        outputs.append(
            {
                "boxes": input_boxes,
                "labels": labels,
                "masks": torch.tensor(masks, device=sam2_predictor.device),
            }
        )

    return outputs
Optimize empty tensor handling in _forward_pass_boxes.
The handling of empty detections can be optimized by using torch.empty with the correct device directly.
-             h, w = img.shape[1], img.shape[2]
+             device = sam2_predictor.device
              outputs.append(
                  {
                      "boxes": torch.tensor([[]], device=device),
                      "labels": torch.empty([0, 4], device=device),
                      "masks": torch.empty([0, 1, h, w], device=device),
                  }
              )
Committable suggestion
def _forward_pass_boxes(self, imgs):
    sam2_predictor = self._load_predictor()
    self._output_processor = fout.InstanceSegmenterOutputProcessor(
        self._curr_classes
    )
    outputs = []
    for img, detections in zip(imgs, self._curr_prompts):
        ## If no detections, return empty tensors instead of running SAM
        if detections is None or len(detections.detections) == 0:
            device = sam2_predictor.device
            h, w = img.shape[1], img.shape[2]
            outputs.append(
                {
                    "boxes": torch.tensor([[]], device=device),
                    "labels": torch.empty([0, 4], device=device),
                    "masks": torch.empty([0, 1, h, w], device=device),
                }
            )
            continue
        inp = fosam._to_sam_input(img)
        sam2_predictor.set_image(inp)
        h, w = img.size(1), img.size(2)
        boxes = [d.bounding_box for d in detections.detections]
        sam_boxes = np.array(
            [fosam._to_sam_box(box, w, h) for box in boxes]
        )
        input_boxes = torch.tensor(sam_boxes, device=sam2_predictor.device)
        labels = torch.tensor(
            [
                self._curr_classes.index(d.label)
                for d in detections.detections
            ],
            device=sam2_predictor.device,
        )
        masks, _, _ = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=sam_boxes[None, :],
            multimask_output=False,
        )
        if masks.ndim == 3:
            masks = np.expand_dims(masks, axis=0)
        outputs.append(
            {
                "boxes": input_boxes,
                "labels": labels,
                "masks": torch.tensor(masks, device=sam2_predictor.device),
            }
        )
    return outputs
LGTM 👍 Works on my end :)
fiftyone/utils/sam.py
Outdated
I like these changes. Unification and improved readability for the win!
Actionable comments posted: 0
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files ignored due to path filters (1)
- fiftyone/zoo/models/manifest-torch.json is excluded by `!**/*.json`
Files selected for processing (2)
- docs/scripts/make_model_zoo_docs.py (3 hunks)
- fiftyone/utils/sam2.py (1 hunks)
Additional context used
Ruff
fiftyone/utils/sam2.py
532-532: Do not perform function call `torch.device` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable (B008)
Additional comments not posted (8)
docs/scripts/make_model_zoo_docs.py (2)
101-109: LGTM! Handling of video datasets is correct. The logic for loading and processing video datasets is well-implemented and meets the specified requirements.
Line range hint 120-143: LGTM! Enhanced handling of 'segment-anything' tags. The changes effectively differentiate between video and non-video datasets, applying models with appropriate parameters.
fiftyone/utils/sam2.py (6)
50-59: Validate `points_mask_index` more robustly. Ensure `points_mask_index` is an integer before checking its range. Consider using `isinstance` to check the type.
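A minimal sketch of the stricter check described in this comment. The helper name `validate_points_mask_index` is hypothetical and not part of the PR; the inclusive 0 to 2 range is the one cited in this review's verification notes, and the real config class may organize the check differently.

```python
def validate_points_mask_index(points_mask_index):
    """Return the index unchanged if it is an int in [0, 2], else raise.

    Hypothetical helper for illustration only; not the PR's code.
    """
    # bool is a subclass of int in Python, so reject it explicitly
    if isinstance(points_mask_index, bool) or not isinstance(
        points_mask_index, int
    ):
        raise TypeError(
            "points_mask_index must be an int, got %s"
            % type(points_mask_index).__name__
        )

    # Only compare against the range once the type is known to be int
    if not 0 <= points_mask_index <= 2:
        raise ValueError(
            "points_mask_index must be in [0, 2], got %d" % points_mask_index
        )

    return points_mask_index
```

Checking the type first means a value such as `"1"` or `1.5` fails with a clear `TypeError` rather than an unrelated comparison error or a silent pass.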
62-69: LGTM! Configuration class for video model is appropriate. The class serves its purpose as a configuration holder for the SAM2 video model.
168-220: Optimize empty tensor handling in `_forward_pass_boxes`. The handling of empty detections can be optimized by using `torch.empty` with the correct device directly.
313-329: Ensure `prompt_field` is not `None` before checking its prefix. The `_get_field` method should first check if `prompt_field` is `None` before verifying its prefix to prevent potential errors.
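A rough sketch of the ordering this comment asks for. The function `get_frame_field`, its `needs_fields` argument, and the `"frames."` prefix are assumptions made for illustration; the PR's `_get_field` method may look different.

```python
def get_frame_field(needs_fields):
    """Resolve a frame-level prompt field name from a dict of fields.

    Hypothetical stand-in for the method discussed above; the "frames."
    prefix is an assumed naming convention for frame-level fields.
    """
    prompt_field = needs_fields.get("prompt_field")
    if prompt_field is None:
        # Fall back to the first available field, if any
        prompt_field = next(iter(needs_fields.values()), None)

    # Check for None first, so the prefix test below never runs on None
    if prompt_field is None:
        raise ValueError("A prompt field is required")

    if not prompt_field.startswith("frames."):
        raise ValueError(
            "Expected a frame-level field starting with 'frames.', got %r"
            % prompt_field
        )

    return prompt_field[len("frames."):]
```

With the checks in this order, a missing field produces a descriptive `ValueError` instead of an `AttributeError` from calling `startswith` on `None`.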
364-368: Handle unsupported prompt types in `_forward_pass`. Consider adding a default case to handle unsupported prompt types gracefully.
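One way such a default case could look, sketched as a standalone function. The name `forward_pass`, the `handlers` mapping, and the example prompt type names are hypothetical; they only illustrate a dictionary dispatch that fails loudly on unknown keys.

```python
def forward_pass(prompt_type, imgs, handlers):
    """Dispatch to the handler registered for prompt_type.

    Hypothetical sketch: `handlers` maps prompt type names (or None for a
    promptless mode) to callables that accept the batch of images.
    """
    handler = handlers.get(prompt_type)
    if handler is None:
        # Default case: report the bad value and the supported options
        supported = sorted(handlers, key=str)
        raise ValueError(
            "Unsupported prompt type %r; expected one of %s"
            % (prompt_type, supported)
        )

    return handler(imgs)


# Example usage with placeholder handlers
example_handlers = {
    "boxes": lambda imgs: ("boxes", len(imgs)),
    "points": lambda imgs: ("points", len(imgs)),
    None: lambda imgs: ("auto", len(imgs)),
}
result = forward_pass("boxes", ["img0", "img1"], example_handlers)
```

Raising an explicit error keeps a misspelled or unexpected prompt type from surfacing later as a bare `KeyError` or a `None` return value.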
525-533: Avoid function call in argument defaults. Move the `torch.device` call inside the function to avoid potential issues.
Tools
Ruff
532-532: Do not perform function call `torch.device` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable (B008)
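For context on why B008 is flagged: Python evaluates default argument expressions once, when the `def` statement runs, not on each call. The sketch below shows that behavior with a stand-in `make_device` function (hypothetical, used here in place of `torch.device` so the example needs no third-party packages).

```python
import itertools

_counter = itertools.count()


def make_device():
    """Stand-in for a call such as torch.device("cuda"); returns a new label
    each time it is invoked so repeated evaluation is visible."""
    return "device-%d" % next(_counter)


def eager(device=make_device()):
    # The default was computed once, when this function was defined
    return device


def lazy(device=None):
    # The default is computed on each call that omits the argument
    if device is None:
        device = make_device()
    return device
```

Every call to `eager()` returns the same value fixed at definition time, while `lazy()` defers the call until it is needed and still lets callers pass an explicit device.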
LGTM, great work! 🥇
Actionable comments posted: 2
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files ignored due to path filters (1)
- fiftyone/zoo/models/manifest-torch.json is excluded by `!**/*.json`
Files selected for processing (3)
- docs/scripts/make_model_zoo_docs.py (4 hunks)
- fiftyone/utils/sam2.py (1 hunks)
- tests/intensive/model_zoo_tests.py (4 hunks)
Files skipped from review as they are similar to previous changes (1)
- docs/scripts/make_model_zoo_docs.py
Additional context used
Ruff
tests/intensive/model_zoo_tests.py
250-253: Use ternary operator `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` instead of `if`-`else`-block. Replace `if`-`else`-block with `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` (SIM108)
fiftyone/utils/sam2.py
538-538: Do not perform function call `torch.device` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable (B008)
Additional comments not posted (9)
tests/intensive/model_zoo_tests.py (5)
79-86: LGTM! The `test_sam2_boxes` function correctly sets up the test for the SAM2 model with box prompts.
89-98: LGTM! The `test_sam2_points` function is correctly configured to test the SAM2 model with keypoints.
101-106: LGTM! The `test_sam2_auto` function is correctly set up for automatic segmentation testing with the SAM2 model.
109-115: LGTM! The `test_sam2_video` function is correctly configured to test video segmentation with the SAM2 model.
Line range hint 194-223: LGTM! Verify the function usage in the codebase. The addition of the `prompt_type` parameter enhances the flexibility of the `_apply_models` function. Ensure that the parameter is used correctly throughout the codebase.
Verification successful
Verification Complete: Optional Parameter Usage Confirmed
The `prompt_type` parameter in the `_apply_models` function is optional, with a default value of `None`. This ensures that its absence in existing function calls does not affect functionality. The parameter is correctly implemented to enhance flexibility without impacting current usage.
- The `_apply_models` function is defined with `prompt_type=None`, confirming its optional nature.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the usage of the `_apply_models` function with the new `prompt_type` parameter.
# Test: Search for the function usage. Expect: Occurrences with the `prompt_type` parameter.
rg --type python -A 5 '_apply_models'
Length of output: 68
Script:
#!/bin/bash
# Description: Verify the usage of the `_apply_models` function with the new `prompt_type` parameter.
# Test: Search for the function usage. Expect: Occurrences with the `prompt_type` parameter.
rg --type py -A 5 '_apply_models'
Length of output: 6048
Script:
#!/bin/bash
# Description: Extract the definition of the `_apply_models` function to verify the handling of the `prompt_type` parameter.
# Extract the function definition to check for default values or conditional logic.
ast-grep --lang python --pattern $'def _apply_models($_, $_, $_, $_, $_, $_, prompt_type=$_):\n $$$'
Length of output: 103
Script:
#!/bin/bash
# Description: Extract the definition of the `_apply_models` function to verify the handling of the `prompt_type` parameter.
# Simplified pattern to match the function definition with the `prompt_type` parameter.
ast-grep --lang python --pattern 'def _apply_models($$$):'
Length of output: 60
Script:
#!/bin/bash
# Description: Search for the definition of the `_apply_models` function using `rg` to verify the handling of the `prompt_type` parameter.
# Use `rg` to find the function definition.
rg --type py -A 10 'def _apply_models'
Length of output: 1375
Tools
Ruff
215-216: Use a single `if` statement instead of nested `if` statements. Combine `if` statements using `and` (SIM102)
250-253: Use ternary operator `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` instead of `if`-`else`-block. Replace `if`-`else`-block with `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` (SIM108)
fiftyone/utils/sam2.py (4)
63-75: LGTM! The `SegmentAnything2VideoModelConfig` class is correctly configured.
77-225: LGTM! The `SegmentAnything2ImageModel` class is well-implemented, with efficient model loading and prediction methods.
36-60: LGTM! Verify validation logic. The `SegmentAnything2ImageModelConfig` class is well-structured with appropriate validation for `points_mask_index`. Ensure the validation logic is robust across the codebase.
Verification successful
Validation Logic for `points_mask_index` is Consistent Across Files
The `points_mask_index` is used consistently in both `fiftyone/utils/sam2.py` and `fiftyone/utils/sam.py`. The validation logic ensuring the index is between 0 and 2 is applied in both files. This confirms that the validation logic is robust across the codebase.
- `fiftyone/utils/sam2.py`: Lines related to `points_mask_index` include its definition and validation.
- `fiftyone/utils/sam.py`: Similar usage and validation logic are present.
Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the usage of `points_mask_index` validation logic.
# Test: Search for `points_mask_index` usage. Expect: Proper validation logic.
rg --type python 'points_mask_index'
Length of output: 67
Script:
#!/bin/bash
# Description: Verify the usage of `points_mask_index` validation logic.
# Test: Search for `points_mask_index` usage. Expect: Proper validation logic.
rg --type py 'points_mask_index'
Length of output: 787
228-529: LGTM! Verify prompt handling logic. The `SegmentAnything2VideoModel` class is well-implemented, with robust prompt handling and prediction methods. Ensure the prompt handling logic is used correctly throughout the codebase.
    if pass_confidence_thresh:
        kwargs = {"confidence_thresh": confidence_thresh}
    else:
        kwargs = {}
Simplify with a ternary operator.
The `if-else` block can be simplified using a ternary operator for clarity.
- if pass_confidence_thresh:
- kwargs = {"confidence_thresh": confidence_thresh}
- else:
- kwargs = {}
+ kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- if pass_confidence_thresh:
-     kwargs = {"confidence_thresh": confidence_thresh}
- else:
-     kwargs = {}
+ kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}
Tools
Ruff
250-253: Use ternary operator `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` instead of `if`-`else`-block. Replace `if`-`else`-block with `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` (SIM108)
33a844e to 9cb6c58
Actionable comments posted: 1
Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Files ignored due to path filters (1)
- fiftyone/zoo/models/manifest-torch.json is excluded by `!**/*.json`
Files selected for processing (5)
- docs/scripts/make_model_zoo_docs.py (4 hunks)
- fiftyone/utils/sam.py (6 hunks)
- fiftyone/utils/sam2.py (1 hunks)
- pylintrc (2 hunks)
- tests/intensive/model_zoo_tests.py (4 hunks)
Files skipped from review as they are similar to previous changes (2)
- docs/scripts/make_model_zoo_docs.py
- pylintrc
Additional context used
Ruff
tests/intensive/model_zoo_tests.py
250-253: Use ternary operator `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` instead of `if`-`else`-block. Replace `if`-`else`-block with `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` (SIM108)
fiftyone/utils/sam2.py
538-538: Do not perform function call `torch.device` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable (B008)
Additional comments not posted (11)
tests/intensive/model_zoo_tests.py (5)
79-86: New test function addition looks good. The `test_sam2_boxes` function is a well-structured addition, consistent with the existing test functions.
89-98: New test function addition looks good. The `test_sam2_points` function is a well-structured addition, consistent with the existing test functions, and correctly uses `prompt_type`.
101-106: New test function addition looks good. The `test_sam2_auto` function is a well-structured addition, consistent with the existing test functions.
109-115: New test function addition looks good. The `test_sam2_video` function is a well-structured addition, consistent with the existing test functions, and correctly uses `_apply_video_models`.
Line range hint 194-223: Enhancement to `_apply_models` looks good. The addition of the `prompt_type` parameter enhances the function's flexibility to handle different model requirements. However, verify that `prompt_type` is correctly used throughout the codebase.
Tools
Ruff
215-216: Use a single `if` statement instead of nested `if` statements. Combine `if` statements using `and` (SIM102)
250-253: Use ternary operator `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` instead of `if`-`else`-block. Replace `if`-`else`-block with `kwargs = {"confidence_thresh": confidence_thresh} if pass_confidence_thresh else {}` (SIM108)
fiftyone/utils/sam.py (2)
224-231: Refactor of `_forward_pass` improves clarity. Using a dictionary to map prompt types to methods simplifies the control flow and enhances maintainability.
377-385: Enhancement to `_to_sam_points` improves robustness. The addition of the `keypoint` parameter and conditional label extraction enhances the function's capability to handle various input scenarios.
fiftyone/utils/sam2.py (4)
36-60: Configuration for `SegmentAnything2ImageModelConfig` looks good. The initialization and validation of parameters like `auto_kwargs` and `points_mask_index` are correctly handled.
63-75: Configuration for `SegmentAnything2VideoModelConfig` looks good. The class is correctly set up for configuring the SAM2 video model.
77-221: Implementation of `SegmentAnything2ImageModel` looks good. The class correctly handles model loading and inference for images, with well-structured methods.
228-450: Implementation of `SegmentAnything2VideoModel` looks good. The class correctly handles video processing and inference, with well-structured methods.
def load_fiftyone_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    compute_device=torch.device("cuda"),
):
Move `torch.device` call inside `load_fiftyone_video_frames`.
The `torch.device` call should be moved inside the function to avoid potential issues with argument defaults.
- compute_device=torch.device("cuda"),
+ compute_device=None,
...
+ if compute_device is None:
+ compute_device = torch.device("cuda")
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- def load_fiftyone_video_frames(
-     video_path,
-     image_size,
-     offload_video_to_cpu,
-     img_mean=(0.485, 0.456, 0.406),
-     img_std=(0.229, 0.224, 0.225),
-     async_loading_frames=False,
-     compute_device=torch.device("cuda"),
- ):
+ def load_fiftyone_video_frames(
+     video_path,
+     image_size,
+     offload_video_to_cpu,
+     img_mean=(0.485, 0.456, 0.406),
+     img_std=(0.229, 0.224, 0.225),
+     async_loading_frames=False,
+     compute_device=None,
+ ):
+     if compute_device is None:
+         compute_device = torch.device("cuda")
Tools
Ruff
538-538: Do not perform function call `torch.device` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable (B008)
* Adding SAM2 into the model zoo
* Updating model tags
* Adding tests
* Cleaning up _forward_pass
* Updating package requirements
* Updating to work on CPU
* Refactoring code
* Removing 0 masks
* Fixing bug with single mask
* Doc changes + package requirements
* Updating manifest
* tweaks
* Adding initialize model._fields

---------

Co-authored-by: Prerna Dhareshwar <prerna@Prernas-MacBook-Pro.local>
Co-authored-by: brimoor <brimoor@umich.edu>
What changes are proposed in this pull request?
Adding Segment Anything 2 (SAM2) to the FiftyOne model zoo.
How is this patch tested? If it is not, please explain why.
Tested manually with different configurations -
Release Notes
Is this a user-facing change that should be mentioned in the release notes?
notes for FiftyOne users.
Added SAM2 to the FiftyOne model zoo with inference support for both images and videos.
What areas of FiftyOne does this PR affect?
- `fiftyone` Python library changes
Box prompt for Images
Keypoint prompt for Images
Automatic segmentation for Images
Prompting for Videos
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Chores