SAM2 for segmenting a 2 hour video? #264

Open
aendrs opened this issue Aug 26, 2024 · 16 comments

aendrs commented Aug 26, 2024

In your opinion, would it be possible to use SAM2 to segment a 2-hour video (720p, 60fps) on a 4090 GPU while, of course, avoiding out-of-memory errors?
What would be the best strategy for doing so?

@kevinpl07

You would have to do it in 10-second chunks. You could take the mask of the last frame of each chunk and use it as the input prompt for the next chunk. That would take a while but would be fully automated.
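
A rough sketch of that chunked approach, assuming the frames have already been split into per-chunk JPEG folders (the chunks/chunk_000-style paths, the prompt coordinates, and the 0.0 logit threshold are all placeholders, not part of the original suggestion):

import numpy as np
from sam2.build_sam import build_sam2_video_predictor

# Hypothetical folders of JPEG frames: 2 hours pre-split into 10 s chunks (720 of them)
chunk_folders = [f"chunks/chunk_{i:03d}" for i in range(720)]

predictor = build_sam2_video_predictor(
    "sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt", "cuda"
)

prev_mask = None  # mask of the last frame of the previous chunk
for chunk_path in chunk_folders:
    state = predictor.init_state(video_path=chunk_path)
    if prev_mask is None:
        # First chunk: prompt with a point (coordinates are placeholders)
        predictor.add_new_points(
            inference_state=state, frame_idx=0, obj_id=1,
            points=np.array([[210, 350]], dtype=np.float32),
            labels=np.array([1], np.int32),
        )
    else:
        # Later chunks: re-prompt with the mask carried over from the previous chunk
        predictor.add_new_mask(inference_state=state, frame_idx=0, obj_id=1, mask=prev_mask)
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        pass  # save / stream this chunk's masks here
    # Threshold the last frame's logits into a binary mask for prompting the next chunk
    prev_mask = (mask_logits[0, 0] > 0.0).cpu()
    predictor.reset_state(state)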

heyoeyo commented Aug 26, 2024

The largest model uses <2GB of VRAM for videos, so a 4090 should have no issues. The main problem would be the likelihood of the segmentation failing at some point combined with the time it takes (using the large model at 60fps, I'd guess it would be 3-4 hours on a 4090), since that's a long time to have to sit there and correct the outputs. It might make sense to first run the tiny model at 512px resolution (see issue #257), which should take <1hr, to give some idea of where the tracking struggles.
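
For reference, a lower-resolution first pass can presumably be set up with a hydra override when building the predictor; treat the ++model.image_size override below as an assumption based on issue #257 rather than something spelled out in this thread:

from sam2.build_sam import build_sam2_video_predictor

# Assumed override name; model.image_size is the input resolution in the SAM2 configs
predictor = build_sam2_video_predictor(
    "sam2_hiera_t.yaml",
    "checkpoints/sam2_hiera_tiny.pt",
    device="cuda",
    hydra_overrides_extra=["++model.image_size=512"],
)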

As for memory build-up in the demo, the original code is set up for interactive use and won't work as-is. You'd have to clear the cached results as the video runs (see #196) and probably also avoid loading all the frames in at the start... I guess by a combination of using the async_loading_frames option on init_state and disabling (i.e. commenting out) the storage of async loaded frames.

Alternatively, there are existing code bases that are aimed at this, for example #90, maybe PR #46, maybe #73 and I also have a script for it.

aendrs commented Aug 28, 2024

Thanks, I'll take a look at the links you provided. Could you explain to me what async_loading_frames does?

heyoeyo commented Aug 28, 2024

Could you explain to me what async_loading_frames does?

By default, the video predictor loads & preprocesses every single frame of your video before doing any segmentation. If you run the examples, you'll see this show up as a progress bar when you begin tracking:

frame loading (JPEG): 100%|███████| 200/200

Only after this finishes does the SAM model actually start doing anything. The model results show up as a different progress bar:

propagate in video:  22%|████     | 45/200

When you set async_loading_frames=True, the frame loading and SAM model run at the same time.

In theory the async loading is a far more practical choice, because it avoids loading everything into memory. Weirdly, the loader runs in its own thread and actually does store everything in memory, which sort of defeats the purpose. But you can fix it by commenting out the storage line like I mentioned before, and it's also probably worth commenting out the threading lines to stop the loader from trying to get ahead of the model. Those changes should allow you to run any length of video, but the predictor still caches results as it runs (around ~1MB per frame; a 2-hour video at 60fps is roughly 432,000 frames, so on the order of hundreds of GB if nothing is cleared), which will eventually consume all your memory for longer videos (but that can be fixed with the other changes I mentioned).

Here's a minimal video example that prints out VRAM usage. You can try running it with a different async setting and with/without the threading/storage lines commented out to see the differences:

from time import perf_counter
import torch
import numpy as np
from sam2.build_sam import build_sam2_video_predictor

video_folder_path = "notebooks/videos/bedroom"
cfg, ckpt = "sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"
device = "cuda" # or "cpu"
predictor = build_sam2_video_predictor(cfg, ckpt, device)
inference_state = predictor.init_state(
    video_path=video_folder_path,
    async_loading_frames=False
)

predictor.add_new_points(
    inference_state=inference_state,
    frame_idx=0,
    obj_id=1,
    points=np.array([[210, 350]], dtype=np.float32),
    labels=np.array([1], np.int32),
)

tprev = -1
for result in predictor.propagate_in_video(inference_state):
    # Do nothing with the results, just report VRAM use about once per second
    if (perf_counter() > tprev + 1.0) and torch.cuda.is_available():
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        print("VRAM:", (total_bytes - free_bytes) // 1_000_000, "MB")
        tprev = perf_counter()

When I run this, the worst case scenario is the original code with async=True which uses >2.5GB VRAM and keeps ballooning as it runs. The best case is also with async=True but with threading & storage commented out, which ends up needing around 1.1GB (but will still grow slowly without clearing cached results).

@JamesMcCullochDickens

@heyoeyo

Any ideas for how to clear the cache in addition to doing async=True but with threading & storage commented out?

Thanks.

heyoeyo commented Mar 23, 2025

how to clear the cache

Yes, there's a link to some code for this in the post above (e.g. issue 196).
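
For reference, the general idea is to periodically drop old per-frame entries from the predictor's inference_state while tracking runs. This is just a sketch of that idea (not the code from #196); the output_dict / output_dict_per_obj layout is taken from the SAM2 video predictor, and keep_last_n is an arbitrary choice:

def prune_old_frame_outputs(inference_state, keep_last_n=16):
    """Delete cached non-conditioning outputs for all but the most recent
    keep_last_n frames (the dict keys are frame indices)."""
    non_cond = inference_state["output_dict"]["non_cond_frame_outputs"]
    for frame_idx in sorted(non_cond.keys())[:-keep_last_n]:
        del non_cond[frame_idx]
    # The same outputs are also cached per object
    for obj_output_dict in inference_state["output_dict_per_obj"].values():
        obj_non_cond = obj_output_dict["non_cond_frame_outputs"]
        for frame_idx in sorted(obj_non_cond.keys())[:-keep_last_n]:
            del obj_non_cond[frame_idx]

Calling something like this every so often inside the propagate_in_video loop keeps the state roughly constant in size; keep_last_n should stay comfortably above the model's num_maskmem so the memory attention still has the recent frames it needs.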

Grpab commented Mar 31, 2025

The largest model uses <2GB of VRAM for videos, so a 4090 should have no issues. [...] Alternatively, there are existing code bases that are aimed at this, for example #90, maybe PR #46, maybe #73 and I also have a script for it.

I tried offload_video_to_cpu=True and async_video_to_cpu=True, but that didn't solve the out-of-memory problem in samurai, so I still can't process long videos. Do you have any better methods?

JamesMcCullochDickens commented Mar 31, 2025

@Grpab I do have a solution; I'm surprised how unclear a lot of this has been.

First, I add a lazy video loader instead of loading all frames at once. This can be done by modifying the code in init_state to return an object with __getitem__ rather than all the frames up front.

In utils/misc.py I add:

class LazyVideoFrameLoader:
    """Decode, resize, and normalize video frames on demand instead of all up front."""

    def __init__(
        self,
        video_path,
        image_size,
        offload_video_to_cpu,
        img_mean=(0.485, 0.456, 0.406),
        img_std=(0.229, 0.224, 0.225),
        compute_device=torch.device("cuda"),
    ):
        import decord
        decord.bridge.set_bridge("torch")
        self.vr = decord.VideoReader(video_path, width=image_size, height=image_size)
        self.image_size = image_size
        self.offload_video_to_cpu = offload_video_to_cpu
        self.img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
        self.img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
        self.compute_device = compute_device

        # Move mean/std to device if needed
        if not offload_video_to_cpu:
            self.img_mean = self.img_mean.to(compute_device)
            self.img_std = self.img_std.to(compute_device)

        # Store height/width info
        first_frame = self.vr[0]
        self.video_height, self.video_width = first_frame.shape[:2]

    def __len__(self):
        return len(self.vr)

    def __getitem__(self, frame_idx):
        frame = self.vr[frame_idx]
        frame = frame.permute(2, 0, 1).float() / 255.0  # C, H, W

        if not self.offload_video_to_cpu:
            frame = frame.to(self.compute_device)

        frame = (frame - self.img_mean) / self.img_std
        return frame


def load_video_frames_lazy(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    compute_device=torch.device("cuda"),
):
    """
    Returns a lazy video loader that normalizes and resizes frames on-the-fly,
    along with the original video height and width.
    """
    import decord
    # A second, unresized reader is opened only to read the original frame dimensions
    # (the loader's own reader is resized to image_size x image_size).
    vr = decord.VideoReader(video_path)
    video_height, video_width = vr[0].shape[:2]
    return LazyVideoFrameLoader(
        video_path,
        image_size,
        offload_video_to_cpu,
        img_mean,
        img_std,
        compute_device,
    ), video_height, video_width

Then in init_state I have:

def init_state(
    self,
    video_path,
    offload_video_to_cpu=False,
    offload_state_to_cpu=False,
    async_loading_frames=False,
    lazy_im_loading=True,
):
    """Initialize an inference state."""
    compute_device = self.device  # device of the model
    if not lazy_im_loading:
        images, video_height, video_width = load_video_frames(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            compute_device=compute_device,
        )
    else:
        images, video_height, video_width = load_video_frames_lazy(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            compute_device=compute_device,
        )
    ...  # more code, just truncating here

Then I delete older conditional and non-conditional frame embeddings after track_step in sam2_video_predictor.py,
in _run_single_frame_inference:

def _run_single_frame_inference(
            self,
            inference_state,
            output_dict,
            frame_idx,
            batch_size,
            is_init_cond_frame,
            point_inputs,
            mask_inputs,
            reverse,
            run_mem_encoder,
            prev_sam_mask_logits=None,
    ):
        """Run tracking on a single frame based on current inputs and previous memory."""
        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # point and mask should not appear as input simultaneously on the same frame
        assert point_inputs is None or mask_inputs is None
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            output_dict=output_dict,
            num_frames=inference_state["num_frames"],
            track_in_reverse=reverse,
            run_mem_encoder=run_mem_encoder,
            prev_sam_mask_logits=prev_sam_mask_logits,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            pred_masks_gpu = fill_holes_in_mask_scores(
                pred_masks_gpu, self.fill_hole_area
            )
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
        obj_ptr = current_out["obj_ptr"]
        object_score_logits = current_out["object_score_logits"]

        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
            "object_score_logits": object_score_logits,
        }

        # Added: cap how many non-conditioning frame outputs stay cached, so memory
        # use stays bounded on long videos (tune max_memory_frames for your setup)
        max_memory_frames = 5
        non_cond_outputs = output_dict["non_cond_frame_outputs"]
        if len(non_cond_outputs) > max_memory_frames:
            oldest_idx = min(non_cond_outputs.keys())
            del non_cond_outputs[oldest_idx]

        # TODO haven't tested this
        """
        cond_outputs = output_dict["cond_frame_outputs"]
        if len(cond_outputs) > max_memory_frames:
            oldest_idx = min(cond_outputs.keys())
            del cond_outputs[oldest_idx]
        """

        # unchanged from the original function: return the compact output and masks
        return compact_current_out, pred_masks

I haven't tested the above as much, but that's the general idea.

And most importantly, I modify build_sam2_video_predictor as follows:

# (in sam2/build_sam.py, which already imports compose, instantiate, OmegaConf and defines _load_checkpoint)
def build_sam2_video_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    vos_optimized=False,
    **kwargs,
):
    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
    ]
    if vos_optimized:
        hydra_overrides = [
            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
            "++model.compile_image_encoder=True",  # Let sam2_base handle this
            "++model.max_cond_frames_in_attn=5", # IMPORTANT FIX
            "++model.num_maskmem=7",
        ]

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
            "++model.max_cond_frames_in_attn=5",  # IMPORTANT FIX
            "++model.num_maskmem=7",
        ]
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model

Again, tune these added parameters to your own situation, but for me this worked like a charm: no OOM and great segmentation results on a very long video.
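
If it helps, here is roughly how the pieces above would fit together; this is only my reading of the modifications, not code from the post, and the video path and prompt coordinates are placeholders:

import numpy as np
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    "sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt", device="cuda"
)

# lazy_im_loading=True routes through the decord-backed LazyVideoFrameLoader above,
# so frames are decoded on demand instead of being loaded up front
state = predictor.init_state(video_path="very_long_video.mp4", lazy_im_loading=True)

predictor.add_new_points(
    inference_state=state, frame_idx=0, obj_id=1,
    points=np.array([[210, 350]], dtype=np.float32),
    labels=np.array([1], np.int32),
)

for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
    # consume the masks immediately; old non-conditioning outputs get pruned
    # inside _run_single_frame_inference as tracking runs
    ...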

Grpab commented Mar 31, 2025

Again, tune these added parameters to your situation, but for me this worked like a charm: no OOM and [...]

Thank you very much for your reply; I will try your method. Do you know how to modify samurai to process video in real time?

@JamesMcCullochDickens

@Grpab I have no idea lol.

Grpab commented Mar 31, 2025

@Grpab I have no idea lol.

Then I delete older conditional and non-conditional frame embeddings in track_step in sam2_video_predictor.py
in _run_single_frame_inference

Can you elaborate on this? “Add this at the end of the function (use the best_memory_frames for your setup), here below only for non_cond_frame_outputs, but you get the idea... ”

Grpab commented Mar 31, 2025

@Grpab I have no idea lol.

Thank you very much for your help!

JamesMcCullochDickens commented Mar 31, 2025

As far as I am aware,

init_state takes in a video path in SAM2.

You are using Samurai rather than SAM2, so I am not sure how to fix that; you will need to patch the compatibilities yourself, and that's the extent of my help. Best of luck. They may be using the legacy version of sam2_video_predictor, but I don't think it should be too hard: just study the code and make the necessary adjustments.

Grpab commented Mar 31, 2025

As far as I am aware,

init_state takes in a video path in SAM2.

You are using Samurai rather than SAM2, so I am not sure how to fix that; you will need to patch the compatibilities yourself, and that's the extent of my help. Best of luck. They may be using the legacy version of sam2_video_predictor, but I don't think it should be too hard: just study the code and make the necessary adjustments.

OK, thank you

Grpab commented Mar 31, 2025

First, I add a lazy video loader instead of loading all frames at once [...]

I used your method and it was very effective: I went from handling about 20 seconds of video to one minute, but anything longer than a minute still runs out of memory.

@GoldenFishes

First, I add a lazy video loader instead of loading all frames at once [...]

I used your method and it was very effective: I went from handling about 20 seconds of video to one minute, but anything longer than a minute still runs out of memory.

This workflow continuously releases old frames so that RAM and GPU memory overhead stay constant when processing videos of unbounded length:
https://github.com/motern88/Det-SAM2
https://arxiv.org/abs/2411.18977

However, it requires retrieving each processed frame's results in time for streaming output; if frame results are released before they are accessed, the inference computation spent on them is wasted.
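
A minimal sketch of the consume-as-you-go pattern this implies, using the standard SAM2 propagation loop (the output directory, PNG format, and 0.0 logit threshold are arbitrary choices, not Det-SAM2's actual API):

import os
import numpy as np
from PIL import Image

def save_masks_as_they_arrive(predictor, inference_state, out_dir="mask_output"):
    """Write each frame's masks to disk as soon as they are produced, so that
    per-frame results can be safely released afterwards without losing work."""
    os.makedirs(out_dir, exist_ok=True)
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
        for i, obj_id in enumerate(obj_ids):
            mask = (mask_logits[i, 0] > 0.0).cpu().numpy().astype(np.uint8) * 255
            Image.fromarray(mask).save(os.path.join(out_dir, f"{frame_idx:06d}_obj{obj_id}.png"))
        # ...any pruning of old cached outputs can happen here, after the masks are saved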
