Project structure and pre-commit changes #72

Merged (5 commits) on Jun 26, 2023
63 changes: 63 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,63 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-yaml
- id: check-docstring-first
- id: check-executables-have-shebangs
- id: check-toml
- id: check-case-conflict
- id: check-added-large-files
args: ['--maxkb=2048']
exclude: ^logo/
- id: detect-private-key
- id: forbid-new-submodules
- id: pretty-format-json
args: ['--autofix', '--no-sort-keys', '--indent=4']
- id: end-of-file-fixer
- id: mixed-line-ending
- repo: https://github.com/asottile/pyupgrade
rev: v3.4.0
hooks:
- id: pyupgrade
args:
- --py3-plus
- --keep-runtime-typing
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.270
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: isort (python)
- id: isort
name: isort (cython)
types: [cython]
- id: isort
name: isort (pyi)
types: [pyi]
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/PyCQA/bandit
rev: '1.7.5'
hooks:
- id: bandit
args: ["-c", "pyproject.toml"]
additional_dependencies: ["bandit[toml]"]
- repo: https://github.com/PyCQA/autoflake
rev: v2.1.1
hooks:
- id: autoflake

ci:
autofix_commit_msg: "dev(pre-commit):🎨 Auto format from pre-commit.com hooks"
autoupdate_commit_msg: "dev(pre-commit):⬆ pre-commit autoupdate"
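
Note: the ci: block above configures pre-commit.ci, which runs these hooks automatically on pull requests. To apply the same hooks locally, the standard pre-commit workflow should be enough: install the tool with "pip install pre-commit", register the git hook with "pre-commit install", and run "pre-commit run --all-files" once to format the existing codebase.
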
1 change: 0 additions & 1 deletion MANIFEST.in

This file was deleted.

6 changes: 3 additions & 3 deletions README.md
@@ -28,7 +28,7 @@ from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor
results = SegAutoMaskPredictor().image_predict(
source="image.jpg",
model_type="vit_l", # vit_l, vit_h, vit_b
points_per_side=16,
points_per_side=16,
points_per_batch=64,
min_area=0,
output_path="output.jpg",
@@ -40,7 +40,7 @@ results = SegAutoMaskPredictor().image_predict(
results = SegAutoMaskPredictor().video_predict(
source="video.mp4",
model_type="vit_l", # vit_l, vit_h, vit_b
points_per_side=16,
points_per_side=16,
points_per_batch=64,
min_area=1000,
output_path="output.mp4",
@@ -121,7 +121,7 @@ image = falai_automask_image(
points_per_side=16,
points_per_batch=32,
min_area=0,
)
)
image.show() # Show image
image.save("output.jpg") # Save image

12 changes: 7 additions & 5 deletions metaseg/__init__.py
@@ -4,9 +4,11 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from metaseg.generator.automatic_mask_generator import SamAutomaticMaskGenerator
from metaseg.generator.build_sam import build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, sam_model_registry
from metaseg.generator.predictor import SamPredictor
from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor
from .falai_demo import automask_image as automask_image
from .falai_demo import falai_automask_image as falai_automask_image
from .falai_demo import falai_manuelmask_image as falai_manuelmask_image
from .falai_demo import manuelmask_image as manuelmask_image
from .sahi_predict import SahiAutoSegmentation as SahiAutoSegmentation
from .sahi_predict import sahi_sliced_predict as sahi_sliced_predict

__version__ = "0.7.4"
__version__ = "0.7.5"
13 changes: 10 additions & 3 deletions metaseg/falai_demo.py
@@ -8,11 +8,15 @@
try:
from fal_serverless import isolated
except ImportError:
raise ImportError("Please install FalAI library using 'pip install fal_serverless'.")
raise ImportError(
"Please install FalAI library using 'pip install fal_serverless'."
)


@isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4")
def automask_image(data, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0):
def automask_image(
data, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0
):
image_path, output_path = load_server_image(data)
SegAutoMaskPredictor().image_predict(
source=image_path,
@@ -29,6 +33,7 @@ def automask_image(data, model_type="vit_b", points_per_side=16, points_per_batc

return result


@isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4")
def manuelmask_image(
data,
@@ -60,7 +65,9 @@ def manuelmask_image(
return result


def falai_automask_image(image_path, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0):
def falai_automask_image(
image_path, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0
):
with open(image_path, "rb") as f:
data = f.read()

9 changes: 9 additions & 0 deletions metaseg/generator/__init__.py
@@ -0,0 +1,9 @@
from .automatic_mask_generator import (
SamAutomaticMaskGenerator as SamAutomaticMaskGenerator,
)
from .build_sam import build_sam as build_sam
from .build_sam import build_sam_vit_b as build_sam_vit_b
from .build_sam import build_sam_vit_h as build_sam_vit_h
from .build_sam import build_sam_vit_l as build_sam_vit_l
from .build_sam import sam_model_registry as sam_model_registry
from .predictor import SamPredictor as SamPredictor
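
With these re-exports in place, downstream code can pull the generator pieces from the subpackage instead of the individual modules; a minimal sketch using only the names re-exported above:

from metaseg.generator import (
    SamAutomaticMaskGenerator,
    SamPredictor,
    build_sam_vit_b,
    sam_model_registry,
)
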
28 changes: 21 additions & 7 deletions metaseg/generator/automatic_mask_generator.py
@@ -172,7 +172,9 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:

# Encode masks
if self.output_mode == "coco_rle":
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
mask_data["segmentations"] = [
coco_encode_rle(rle) for rle in mask_data["rles"]
]
elif self.output_mode == "binary_mask":
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
else:
@@ -196,7 +198,9 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:

def _generate_masks(self, image: np.ndarray) -> MaskData:
orig_size = image.shape[:2]
crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio)
crop_boxes, layer_idxs = generate_crop_boxes(
orig_size, self.crop_n_layers, self.crop_overlap_ratio
)

# Iterate over image crops
data = MaskData()
@@ -240,7 +244,9 @@ def _process_crop(
# Generate masks for this crop in batches
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
batch_data = self._process_batch(
points, cropped_im_size, crop_box, orig_size
)
data.cat(batch_data)
del batch_data
self.predictor.reset_image()
@@ -273,7 +279,9 @@ def _process_batch(
# Run model on this batch
transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
in_labels = torch.ones(
in_points.shape[0], dtype=torch.int, device=in_points.device
)
masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :],
in_labels[:, None],
@@ -296,7 +304,9 @@ def _process_batch(

# Calculate stability score
data["stability_score"] = calculate_stability_score(
data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
data["masks"],
self.predictor.model.mask_threshold,
self.stability_score_offset,
)
if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
@@ -307,7 +317,9 @@ def _process_batch(
data["boxes"] = batched_mask_to_box(data["masks"])

# Filter boxes that touch crop boundaries
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
keep_mask = ~is_box_near_crop_edge(
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
)
if not torch.all(keep_mask):
data.filter(keep_mask)

@@ -319,7 +331,9 @@ def _process_batch(
return data

@staticmethod
def postprocess_small_regions(mask_data: MaskData, min_area: int, nms_thresh: float) -> MaskData:
def postprocess_small_regions(
mask_data: MaskData, min_area: int, nms_thresh: float
) -> MaskData:
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates.
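
For context, a minimal sketch of how the reformatted generator class is typically driven; the checkpoint path and the placeholder image are assumptions for illustration, and generate() takes a NumPy image array as in the signature shown above:

import numpy as np

from metaseg.generator import SamAutomaticMaskGenerator, build_sam_vit_b

sam = build_sam_vit_b(checkpoint="sam_vit_b.pth")    # assumed checkpoint path
mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=16)

image = np.zeros((480, 640, 3), dtype=np.uint8)      # placeholder; use a real RGB image
masks = mask_generator.generate(image)               # one record per detected mask
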
15 changes: 14 additions & 1 deletion metaseg/generator/build_sam.py
@@ -8,7 +8,13 @@

import torch

from metaseg.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
from metaseg.modeling import (
ImageEncoderViT,
MaskDecoder,
PromptEncoder,
Sam,
TwoWayTransformer,
)


def build_sam_vit_h(checkpoint=None):
@@ -44,6 +50,13 @@ def build_sam_vit_b(checkpoint=None):
)


build_sam_vit_h = {
"default": build_sam,
"vit_h": build_sam,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
}

sam_model_registry = {
"default": build_sam,
"vit_h": build_sam,
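
sam_model_registry maps a model-type string to its builder function, so callers can select a backbone by name; a small sketch (the checkpoint filename is an assumption):

from metaseg.generator import sam_model_registry

build_fn = sam_model_registry["vit_l"]        # "default", "vit_h", "vit_l" or "vit_b"
sam = build_fn(checkpoint="sam_vit_l.pth")    # assumed checkpoint path
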
45 changes: 34 additions & 11 deletions metaseg/generator/predictor.py
@@ -54,7 +54,9 @@ def set_image(
# Transform the image to the form expected by the model
input_image = self.transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device=self.device)
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
None, :, :, :
]

self.set_torch_image(input_image_torch, image.shape[:2])

@@ -79,7 +81,10 @@ def set_torch_image(
len(transformed_image.shape) == 4
and transformed_image.shape[1] == 3
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
), (
f"set_torch_image input must be BCHW with long side "
f"{self.model.image_encoder.img_size}."
)
self.reset_image()

self.original_size = original_image_size
@@ -130,22 +135,32 @@ def predict(
a subsequent iteration as mask input.
"""
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)

# Transform input prompts
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
if point_coords is not None:
assert point_labels is not None, "point_labels must be supplied if point_coords is supplied."
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
coords_torch = torch.as_tensor(
point_coords, dtype=torch.float, device=self.device
)
labels_torch = torch.as_tensor(
point_labels, dtype=torch.int, device=self.device
)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None:
box = self.transform.apply_boxes(box, self.original_size)
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
box_torch = box_torch[None, :]
if mask_input is not None:
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
mask_input_torch = torch.as_tensor(
mask_input, dtype=torch.float, device=self.device
)
mask_input_torch = mask_input_torch[None, :, :, :]

masks, iou_predictions, low_res_masks = self.predict_torch(
@@ -208,7 +223,9 @@ def predict_torch(
a subsequent iteration as mask input.
"""
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)

if point_coords is not None:
points = (point_coords, point_labels)
@@ -232,7 +249,9 @@ def predict_torch(
)

# Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
masks = self.model.postprocess_masks(
low_res_masks, self.input_size, self.original_size
)

if not return_logits:
masks = masks > self.model.mask_threshold
@@ -246,8 +265,12 @@ def get_image_embedding(self) -> torch.Tensor:
the embedding spatial dimension of SAM (typically C=256, H=W=64).
"""
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) to generate an embedding.")
assert self.features is not None, "Features must exist if an image has been set."
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert (
self.features is not None
), "Features must exist if an image has been set."
return self.features

@property
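
A minimal sketch of the prompt-based flow the reformatted SamPredictor supports, following the predict() call shown above; the checkpoint path, the placeholder image, and the click coordinates are illustrative assumptions:

import numpy as np

from metaseg.generator import SamPredictor, build_sam_vit_b

sam = build_sam_vit_b(checkpoint="sam_vit_b.pth")    # assumed checkpoint path
predictor = SamPredictor(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)      # placeholder; use a real RGB image
predictor.set_image(image)                           # computes the image embedding once

masks, iou_predictions, low_res_masks = predictor.predict(
    point_coords=np.array([[320, 240]]),             # one (x, y) foreground click
    point_labels=np.array([1]),                      # 1 = foreground, 0 = background
    multimask_output=True,
)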