HQ-SAM 2: Segment Anything in High Quality for Images and Videos

We propose HQ-SAM 2, which upgrades SAM 2 to higher mask quality by extending our HQ-SAM training strategy.

Latest updates

2024/11/17 -- HQ-SAM 2 is released

  • A new suite of improved model checkpoints (denoted as HQ-SAM 2, beta version) has been released. See Model Description for details.

(Figure: HQ-SAM 2 results comparison)

Installation

HQ-SAM 2 must be installed before use. The code requires python>=3.10, torch>=2.3.1, and torchvision>=0.18.1. Please follow the instructions at https://pytorch.org/ to install both PyTorch and TorchVision. You can install HQ-SAM 2 on a GPU machine using:

git clone https://github.com/SysCV/sam-hq.git
conda create -n sam_hq2 python=3.10 -y
conda activate sam_hq2
cd sam-hq/sam-hq2
pip install -e .
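After installation, it can help to confirm that the active environment meets the version requirements listed above before debugging anything else. The following is a minimal sketch that uses only the standard library (`importlib.metadata`), so it runs even if importing PyTorch fails. `parse_version`, `meets`, and `check_environment` are hypothetical helpers written for this example; `parse_version` is deliberately simplified (it ignores pre-release tags), so for anything beyond a quick check the `packaging` library's version parser is the more robust choice.

```python
"""Check the Python/torch/torchvision versions required by HQ-SAM 2 (sketch)."""
import sys
from importlib import metadata

# Minimum versions as stated in the installation instructions.
REQUIREMENTS = {"torch": (2, 3, 1), "torchvision": (0, 18, 1)}
MIN_PYTHON = (3, 10)


def parse_version(text: str) -> tuple:
    """Turn '2.3.1+cu121' into (2, 3, 1); local tags and suffixes are dropped."""
    core = text.split("+", 1)[0]
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:  # pad "2.4" to (2, 4, 0)
        parts.append(0)
    return tuple(parts)


def meets(installed: str, required: tuple) -> bool:
    """True if the installed version string is >= the required version tuple."""
    return parse_version(installed) >= tuple(required)


def check_environment() -> dict:
    """Return {name: (installed_version_or_None, ok)}."""
    report = {"python": (sys.version.split()[0], sys.version_info[:2] >= MIN_PYTHON)}
    for name, required in REQUIREMENTS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, meets(installed, required))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_environment().items():
        status = "OK" if ok else "NEEDS INSTALL/UPGRADE"
        print(f"{name:12s} {installed or 'not installed':20s} {status}")
```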

If you are installing on Windows, it's strongly recommended to use Windows Subsystem for Linux (WSL) with Ubuntu.

To use the HQ-SAM 2 predictor and run the example notebooks, jupyter and matplotlib are required and can be installed by:

pip install -e ".[notebooks]"

Note:

  1. It's recommended to create a new Python environment via Anaconda for this installation and install PyTorch 2.3.1 (or higher) via pip following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using pip.
  2. The step above requires compiling a custom CUDA kernel with the nvcc compiler. If nvcc isn't already available on your machine, please install a CUDA toolkit whose version matches your PyTorch CUDA version.
  3. If you see a message like "Failed to build the SAM 2 CUDA extension" during installation, you can ignore it and still use HQ-SAM 2 (some post-processing functionality may be limited, but this doesn't affect the results in most cases).
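To find out after the fact whether the optional CUDA extension from note 3 was actually built, one option is simply to try importing it. This is a hedged sketch: the module name `sam2._C` is an assumption based on how the upstream SAM 2 package names its compiled extension, and `cuda_extension_available` is a hypothetical helper, so adjust the name if your build differs.

```python
import importlib


def cuda_extension_available(module_name: str = "sam2._C") -> bool:
    """Return True if the compiled extension module can be imported.

    The default name "sam2._C" is an assumption (upstream SAM 2 naming).
    """
    try:
        importlib.import_module(module_name)
    except (ImportError, OSError):  # missing module or an unloadable shared library
        return False
    return True


if __name__ == "__main__":
    if cuda_extension_available():
        print("CUDA extension found: full post-processing is available.")
    else:
        print("CUDA extension not found: some post-processing may be limited.")
```

A missing extension is not fatal per the note above; the check only tells you which case you are in.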

Please see INSTALL.md for FAQs on potential issues and solutions.

Getting Started

Download Checkpoints

First, we need to download a model checkpoint. All the model checkpoints can be downloaded by running:

cd checkpoints && \
./download_ckpts.sh && \
cd ..

or individually from:

(note that these are the improved checkpoints denoted as SAM 2.1; see Model Description for details.)
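Because the examples below switch between the baseline SAM 2.1 and HQ-SAM 2 by commenting lines in and out, a small lookup table can make the choice explicit and fail early if a checkpoint has not been downloaded yet. This is a sketch: `MODELS` and `resolve_model` are hypothetical names, and the config/checkpoint pairs are copied from the examples in this README. Only the checkpoint path is checked on disk, on the assumption (based on SAM 2's Hydra-based config loading) that `build_sam2` resolves the config path through its own search path rather than relative to the working directory.

```python
from pathlib import Path

# (model_cfg, checkpoint) pairs taken from the examples in this README.
MODELS = {
    "sam2.1_hiera_large": (
        "configs/sam2.1/sam2.1_hiera_l.yaml",
        "./checkpoints/sam2.1_hiera_large.pt",
    ),
    "sam2.1_hq_hiera_large": (
        "configs/sam2.1/sam2.1_hq_hiera_l.yaml",
        "./checkpoints/sam2.1_hq_hiera_large.pt",
    ),
}


def resolve_model(name: str, check_exists: bool = True) -> tuple:
    """Return (model_cfg, checkpoint) for a model name.

    Raises KeyError for unknown names and, when check_exists is True,
    FileNotFoundError if the checkpoint file is not on disk yet.
    """
    if name not in MODELS:
        raise KeyError(f"unknown model {name!r}; choose from {sorted(MODELS)}")
    model_cfg, checkpoint = MODELS[name]
    if check_exists and not Path(checkpoint).is_file():
        raise FileNotFoundError(f"{checkpoint} not found; download the checkpoint first")
    return model_cfg, checkpoint
```

With the checkpoint in place, it could be used as `build_sam2(*resolve_model("sam2.1_hq_hiera_large"))`, since `build_sam2` takes the config first and the checkpoint second, as in the example below.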

Then HQ-SAM 2 can be used in a few lines as follows for image and video prediction.

Image prediction

HQ-SAM 2 has all the capabilities of HQ-SAM on static images, and we provide image prediction APIs that closely resemble SAM for image use cases. The SAM2ImagePredictor class has an easy interface for image prompting.

import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Baseline SAM2.1
# checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
# model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Ours HQ-SAM 2
checkpoint = "./checkpoints/sam2.1_hq_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hq_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>, multimask_output=False)

Please refer to the examples in demo/demo_hqsam2.py (run with python demo/demo_hqsam2.py) for details on how to add click or box prompts.
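The `<input_prompts>` placeholder in the image example stands for click and/or box prompts. In the upstream SAM 2 image predictor these are passed to `predict` as the keyword arguments `point_coords` (an (N, 2) array of (x, y) pixel coordinates), `point_labels` (N labels, 1 for foreground and 0 for background), and `box` (x0, y0, x1, y1 in pixels). The sketch below packs them into that form with some shape validation; `build_prompts` is a hypothetical helper, and it assumes HQ-SAM 2's `SAM2ImagePredictor.predict` keeps the SAM 2 keyword names.

```python
import numpy as np


def build_prompts(points=None, labels=None, box=None) -> dict:
    """Pack click and/or box prompts into keyword arguments for predict().

    points: sequence of (x, y) pixel coordinates
    labels: one label per point, 1 = foreground click, 0 = background click
    box:    (x0, y0, x1, y1) in pixels
    """
    kwargs = {}
    if points is not None:
        coords = np.asarray(points, dtype=np.float32).reshape(-1, 2)
        if labels is None:
            labels = np.ones(len(coords), dtype=np.int32)  # default: all foreground
        labels = np.asarray(labels, dtype=np.int32).reshape(-1)
        if len(labels) != len(coords):
            raise ValueError("need exactly one label per point")
        kwargs["point_coords"] = coords
        kwargs["point_labels"] = labels
    if box is not None:
        box = np.asarray(box, dtype=np.float32).reshape(4)
        if box[2] < box[0] or box[3] < box[1]:
            raise ValueError("box must be (x0, y0, x1, y1) with x1 >= x0 and y1 >= y0")
        kwargs["box"] = box
    if not kwargs:
        raise ValueError("provide at least one point or a box")
    return kwargs
```

Inside the inference context, a single foreground click at a hypothetical pixel (500, 375) would then be `predictor.predict(**build_prompts(points=[(500, 375)], labels=[1]), multimask_output=False)`.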

Please refer to the examples in image_predictor_example.ipynb for static image use cases.
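The image example uses `multimask_output=False`. With `multimask_output=True`, the SAM 2 image predictor returns several candidate masks for a single prompt together with a predicted quality score for each, and a common pattern in the SAM examples is to keep the highest-scoring candidate. A minimal sketch, assuming the single-prompt output shapes (C, H, W) for masks and (C,) for scores; `select_best_mask` is a hypothetical helper.

```python
import numpy as np


def select_best_mask(masks, scores):
    """Return (mask, score) for the highest-scoring candidate.

    masks:  array of shape (C, H, W), one candidate mask per output
    scores: array of shape (C,), the predicted quality of each candidate
    """
    masks = np.asarray(masks)
    scores = np.asarray(scores).reshape(-1)
    if masks.ndim != 3 or masks.shape[0] != scores.shape[0]:
        raise ValueError("expected masks (C, H, W) and scores (C,)")
    best = int(np.argmax(scores))
    return masks[best], float(scores[best])
```

Usage would look like `masks, scores, _ = predictor.predict(<input_prompts>, multimask_output=True)` followed by `mask, score = select_best_mask(masks, scores)`. Batched box prompts add a leading batch dimension, which this sketch does not handle.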

Video prediction

For promptable segmentation and tracking in videos, we provide a video predictor with APIs, for example, to add prompts and propagate masklets throughout a video. As in SAM 2, video inference supports multiple objects and uses an inference state to keep track of the interactions in each video.

import torch
from sam2.build_sam import build_sam2_video_predictor
from sam2.build_sam import build_sam2_hq_video_predictor
# Baseline SAM2.1
# checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
# model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
# predictor = build_sam2_video_predictor(model_cfg, checkpoint)

# Ours HQ-SAM 2
checkpoint = "./checkpoints/sam2.1_hq_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hq_hiera_l.yaml"
predictor = build_sam2_hq_video_predictor(model_cfg, checkpoint)

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...

Please refer to the examples in video_predictor_example.ipynb for video use cases.
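The `...` in the propagation loop is where per-frame results are usually stored. In the upstream SAM 2 examples, the third value yielded by `propagate_in_video` holds mask logits (one entry per object in `object_ids`) that are thresholded at 0.0 to obtain binary masks. The sketch below collects them into a `{frame_idx: {obj_id: mask}}` dictionary under that assumption. `collect_masklets` and `to_numpy` are hypothetical helpers; `to_numpy` converts to float32 first because outputs produced under bfloat16 autocast cannot be converted to NumPy directly.

```python
import numpy as np


def to_numpy(x):
    """Convert a torch tensor (any device or dtype) or an array-like to NumPy."""
    if hasattr(x, "detach"):  # torch.Tensor: move to CPU as float32
        x = x.detach().float().cpu().numpy()
    return np.asarray(x)


def collect_masklets(propagation, threshold=0.0):
    """Gather per-frame, per-object binary masks from propagate_in_video().

    propagation: iterable of (frame_idx, obj_ids, mask_logits), where
                 mask_logits[i] holds the logits for obj_ids[i]
    Returns {frame_idx: {obj_id: boolean mask}}.
    """
    video_segments = {}
    for frame_idx, obj_ids, mask_logits in propagation:
        logits = to_numpy(mask_logits)
        video_segments[frame_idx] = {
            obj_id: logits[i] > threshold for i, obj_id in enumerate(obj_ids)
        }
    return video_segments
```

Inside the same inference context as the example above, this would replace the loop with `video_segments = collect_masklets(predictor.propagate_in_video(state))`.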

Model Description

HQ-SAM 2 checkpoints

The table below shows the zero-shot image segmentation performance (AP) of SAM 2.1 and HQ-SAM 2 on COCO, using the same bounding-box detector (FocalNet-DINO) for both. SAM 2.1 and HQ-SAM 2 run at comparable speed (FPS).

| Model | Size (M) | Single-Mode (AP) | Multi-Mode (AP) |
| --- | --- | --- | --- |
| sam2.1_hiera_large (config, checkpoint) | 224.4 | 50.0 | 48.3 |
| sam2.1_hq_hiera_large (config, checkpoint) | 224.7 | 50.9 | 50.4 |

The table below shows the zero-shot image segmentation AP of Grounded SAM 2 and Grounded HQ-SAM 2 on the SegInW (Segmentation in the Wild) dataset.

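| Model Name | SAM | GroundingDINO | Mean AP | Airplane-Parts | Bottles | Brain-Tumor | Chicken | Cows | Electric-Shaver | Elephants | Fruits | Garbage | Ginger-Garlic | Hand-Metal | Hand | House-Parts | HouseHold-Items | Nutterfly-Squireel | Phones | Poles | Puppies | Rail | Salmon-Fillet | Strawberry | Tablets | Toolkits | Trash | Watermelon |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Grounded SAM2 | vit-l | swin-b | 49.5 | 38.3 | 67.1 | 12.1 | 80.7 | 52.8 | 72.0 | 78.2 | 83.3 | 26.0 | 45.7 | 73.7 | 77.6 | 8.6 | 60.1 | 84.1 | 34.6 | 28.8 | 48.9 | 14.3 | 24.2 | 83.7 | 29.1 | 20.1 | 28.4 | 66.0 |
| Grounded HQ-SAM2 | vit-l | swin-b | 50.0 | 38.6 | 66.8 | 12.0 | 81.0 | 52.8 | 71.9 | 77.2 | 83.3 | 26.1 | 45.5 | 74.8 | 79.0 | 8.6 | 60.1 | 84.7 | 34.3 | 25.5 | 48.9 | 14.1 | 34.1 | 85.7 | 29.2 | 21.5 | 28.9 | 66.6 |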
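The SegInW Mean AP column is consistent with an unweighted average over the 25 per-category values: recomputing it from the rows above gives 49.5 for Grounded SAM2 and 50.0 for Grounded HQ-SAM2. The sketch below does that arithmetic and also lists per-category differences, which shows the mean gain is uneven across categories. All values are copied from the table; `mean_ap` and `per_category_delta` are hypothetical helpers written for this example.

```python
"""Recompute SegInW mean AP and per-category deltas from the table values."""

CATEGORIES = [
    "Airplane-Parts", "Bottles", "Brain-Tumor", "Chicken", "Cows",
    "Electric-Shaver", "Elephants", "Fruits", "Garbage", "Ginger-Garlic",
    "Hand-Metal", "Hand", "House-Parts", "HouseHold-Items", "Nutterfly-Squireel",
    "Phones", "Poles", "Puppies", "Rail", "Salmon-Fillet",
    "Strawberry", "Tablets", "Toolkits", "Trash", "Watermelon",
]
GROUNDED_SAM2 = [
    38.3, 67.1, 12.1, 80.7, 52.8, 72.0, 78.2, 83.3, 26.0, 45.7,
    73.7, 77.6, 8.6, 60.1, 84.1, 34.6, 28.8, 48.9, 14.3, 24.2,
    83.7, 29.1, 20.1, 28.4, 66.0,
]
GROUNDED_HQ_SAM2 = [
    38.6, 66.8, 12.0, 81.0, 52.8, 71.9, 77.2, 83.3, 26.1, 45.5,
    74.8, 79.0, 8.6, 60.1, 84.7, 34.3, 25.5, 48.9, 14.1, 34.1,
    85.7, 29.2, 21.5, 28.9, 66.6,
]


def mean_ap(values):
    """Unweighted mean over categories (each category counts equally)."""
    if not values:
        raise ValueError("no values")
    return sum(values) / len(values)


def per_category_delta(base, new, names=CATEGORIES):
    """Return {category: new - base}, rounded to one decimal like the table."""
    if not (len(base) == len(new) == len(names)):
        raise ValueError("lists must have one value per category")
    return {name: round(n - b, 1) for name, b, n in zip(names, base, new)}


if __name__ == "__main__":
    print(f"Grounded SAM2    mean AP: {mean_ap(GROUNDED_SAM2):.1f}")
    print(f"Grounded HQ-SAM2 mean AP: {mean_ap(GROUNDED_HQ_SAM2):.1f}")
    deltas = per_category_delta(GROUNDED_SAM2, GROUNDED_HQ_SAM2)
    for name, d in sorted(deltas.items(), key=lambda kv: kv[1], reverse=True):
        print(f"{name:20s} {d:+.1f}")
```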

The table below shows the zero-shot video object segmentation performance of SAM2.1 and HQ-SAM 2.

| Model | Size (M) | DAVIS val (J&F) | MOSE (J&F) |
| --- | --- | --- | --- |
| sam2.1_hiera_large (config, checkpoint) | 224.4 | 89.8 | 74.6 |
| sam2.1_hq_hiera_large (config, checkpoint) | 224.7 | 91.0 | 74.7 |

License

HQ-SAM 2, the SAM 2 model checkpoints, the SAM 2 demo code (front-end and back-end), and the SAM 2 training code are licensed under Apache 2.0; however, the Inter Font and Noto Color Emoji used in the SAM 2 demo code are made available under the SIL Open Font License, version 1.1.

Citing HQ-SAM 2

If you find HQ-SAM2 useful in your research or refer to the provided baseline results, please star ⭐ this repository and consider citing 📝:

@inproceedings{sam_hq,
    title={Segment Anything in High Quality},
    author={Ke, Lei and Ye, Mingqiao and Danelljan, Martin and Liu, Yifan and Tai, Yu-Wing and Tang, Chi-Keung and Yu, Fisher},
    booktitle={NeurIPS},
    year={2023}
}