Skip to content

Commit

Permalink
added option to set embeddings for predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Aug 2, 2024
1 parent 0e78a11 commit 6e3e476
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 56 deletions.
286 changes: 230 additions & 56 deletions notebooks/image_predictor_example.ipynb

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions sam2/sam2_image_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,51 @@ def get_image_embedding(self) -> torch.Tensor:
self._features is not None
), "Features must exist if an image has been set."
return self._features["image_embed"]


def get_high_res_features(self) -> torch.Tensor:
"""
Returns the high resolution features for the currently set image.
"""
if not self._is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert (
self._features is not None
), "Features must exist if an image has been set."
return self._features["high_res_feats"]


def set_image_embedding(
self,
image_embedding: torch.Tensor,
img_hw: Tuple,
high_res_features: Optional[List[torch.Tensor]] = None
):
"""
Sets the image embeddings for a previously computed image e.g.
loaded from file.
Arguments:
image_embedding (torch.Tensor): A 1xCxHxW Tensor, where C is the
embedding dimension and (H,W) are the embedding spatial dimension
of SAM (typically C=256, H=W=64).
img_hw (Tuple): Height and Width of the original image.
high_res_features (torch.Tensor or None): Optional high res features
"""
self._is_image_set = True
self._orig_hw = [img_hw]
self._features = {}
self._features["image_embed"] = image_embedding

if high_res_features:
self._features["high_res_feats"] = high_res_features

self._is_batch = False


@property
def device(self) -> torch.device:
Expand Down

0 comments on commit 6e3e476

Please sign in to comment.