
feat(ml-models): update sam from v1 to v2 #489

Merged · 7 commits · Aug 19, 2024
2,185 changes: 1,059 additions & 1,126 deletions poetry.lock

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions pyproject.toml
@@ -27,13 +27,14 @@ redis = "^4.5.4"
pyproj = "^3.6.1"
neptune = "^1.8.3"
ultralytics = "8.1.14"
segment-anything = {git = "https://github.com/facebookresearch/segment-anything.git"}
pyzbar = "^0.1.9"
shapelysmooth = "^0.1.1"
torch = "^2.1.2"
torchvision = "^0.16.2"
flask-babel = "^4.0.0"
my-ultralytics-4bands = { git = "https://github.com/itisacloud/ultralytics_multiband_support.git", rev = "ef61cf9870755ae8b21b03253237d15f5856e1a6" }
torch = {version = "^2.4.0+cpu", source = "pytorch-cpu"}
torchvision = {version = "^0.19.0+cpu", source = "pytorch-cpu"}
sam-2 = {git = "https://github.com/facebookresearch/segment-anything-2.git"}
setuptools = "^72.2.0"

[tool.poetry.group.dev.dependencies]
# Versions are fixed to match versions used by pre-commit
@@ -53,6 +54,12 @@ geopandas = "^0.14.4"
numpy = "^1.26.4"
testcontainers = {extras = ["postgres", "redis"], version = "^4.7.1"}


[[tool.poetry.source]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
priority = "explicit"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
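A quick way to confirm that Poetry resolved the CPU-only wheels from the new explicit `pytorch-cpu` source (a minimal check; the expected output is an assumption based on the pins above, not part of the PR):

import torch

print(torch.__version__)          # expected to end in "+cpu", e.g. "2.4.0+cpu"
print(torch.cuda.is_available())  # expected to be False for the CPU-only build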
2 changes: 1 addition & 1 deletion sketch_map_tool/config.py
@@ -64,6 +64,6 @@ def get_config() -> MappingProxyType:
return MappingProxyType(cfg)


def get_config_value(key: str) -> str | int:
def get_config_value(key: str):
config = get_config()
return config[key]
25 changes: 16 additions & 9 deletions sketch_map_tool/tasks.py
@@ -1,5 +1,4 @@
import logging
import os
from io import BytesIO
from uuid import UUID
from zipfile import ZipFile
@@ -8,7 +7,8 @@
from celery.signals import setup_logging, worker_process_init, worker_process_shutdown
from geojson import FeatureCollection
from numpy.typing import NDArray
from segment_anything import SamPredictor, sam_model_registry
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from ultralytics import YOLO
from ultralytics_4bands import YOLO as YOLO_4

@@ -26,7 +26,11 @@
post_process,
)
from sketch_map_tool.upload_processing.detect_markings import detect_markings
from sketch_map_tool.upload_processing.ml_models import init_model
from sketch_map_tool.upload_processing.ml_models import (
init_model,
init_sam2,
select_computation_device,
)
from sketch_map_tool.wms import client as wms_client


@@ -151,13 +155,16 @@ def digitize_sketches(
) -> AsyncResult | FeatureCollection:
# Initialize ml-models. This has to happen inside the Celery context.
#
# Prevent usage of CUDA while transforming Tensor objects to numpy arrays
# during marking detection
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Zero shot segment anything model for automatic mask generation
sam_path = init_model(get_config_value("neptune_model_id_sam"))
sam_model = sam_model_registry[get_config_value("model_type_sam")](sam_path)
sam_predictor: SamPredictor = SamPredictor(sam_model) # mask predictor
path = init_sam2()
device = select_computation_device()
sam2_model = build_sam2(
config_file="sam2_hiera_b+.yaml",
ckpt_path=path,
device=device,
)
sam_predictor = SAM2ImagePredictor(sam2_model)

# Custom trained model for object detection (obj) and classification (cls)
# of markings and colors.
if "osm" in layers.values():
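For orientation, a minimal sketch of how the SAM-2 predictor built above is used further down the pipeline; the image path and box coordinates are hypothetical, and the call mirrors `mask_from_bbox` in detect_markings.py:

import numpy as np
from PIL import Image

image = np.array(Image.open("map_frame.png").convert("RGB"))  # hypothetical input image
sam_predictor.set_image(image)
masks, scores, _ = sam_predictor.predict(
    box=np.array([10, 10, 200, 200]),  # XYXY bounding box prompt
    multimask_output=False,
)
mask, score = masks[0], scores[0]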
27 changes: 15 additions & 12 deletions sketch_map_tool/upload_processing/detect_markings.py
@@ -1,8 +1,9 @@
import cv2
import numpy as np
import torch
from numpy.typing import NDArray
from PIL import Image, ImageEnhance
from segment_anything import SamPredictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from ultralytics import YOLO
from ultralytics_4bands import YOLO as YOLO_4

@@ -12,7 +13,7 @@ def detect_markings(
map_frame: NDArray,
yolo_obj: YOLO_4,
yolo_cls: YOLO,
sam_predictor: SamPredictor,
sam_predictor: SAM2ImagePredictor,
) -> list[NDArray]:
"""Run machine learning pipeline and post-processing to detect markings.

@@ -83,7 +84,7 @@ def apply_ml_pipeline(
difference: Image.Image,
yolo_obj: YOLO_4,
yolo_cls: YOLO,
sam_predictor: SamPredictor,
sam_predictor: SAM2ImagePredictor,
) -> tuple[list[NDArray], NDArray, list]:
"""Apply the entire machine learning pipeline on an image.

@@ -147,7 +148,7 @@ def apply_yolo_classification(
def apply_sam(
image: Image.Image,
bounding_boxes: NDArray,
sam_predictor: SamPredictor,
sam_predictor: SAM2ImagePredictor,
) -> tuple[list[NDArray], list[np.float32]]:
"""Apply zero-shot SAM (Segment Anything) on an image using bounding boxes.

@@ -157,26 +158,28 @@ def apply_sam(
Returns:
tuple: List of masks and corresponding scores.
"""
sam_predictor.set_image(np.array(image))
masks = []
scores = []
for bbox in bounding_boxes:
mask, score = mask_from_bbox(bbox, sam_predictor)
masks.append(mask)
scores.append(score)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
sam_predictor.set_image(np.array(image))
masks = []
scores = []
for bbox in bounding_boxes:
mask, score = mask_from_bbox(bbox, sam_predictor)
masks.append(mask)
scores.append(score)
return masks, scores


def mask_from_bbox(
bbox: NDArray,
sam_predictor: SamPredictor,
sam_predictor: SAM2ImagePredictor,
) -> tuple:
"""Generate a mask using SAM (Segment Anything) predictor for a given bounding box.

Returns:
tuple: Mask and corresponding score.
"""
masks, scores, _ = sam_predictor.predict(box=bbox, multimask_output=False)
# TODO: should an error be raised if the list has more than one element?
return masks[0], scores[0]


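A hedged usage sketch of `apply_sam` as changed above; the difference image and the bounding box are placeholder values, the boxes in the real pipeline come from the YOLO object-detection step, and `sam_predictor` is the SAM2ImagePredictor constructed as in tasks.py:

import numpy as np
from PIL import Image

from sketch_map_tool.upload_processing.detect_markings import apply_sam

difference = Image.open("difference.png")       # hypothetical marking-difference image
bounding_boxes = np.array([[12, 30, 96, 110]])  # one detected marking, XYXY pixel coordinates
masks, scores = apply_sam(difference, bounding_boxes, sam_predictor)
assert len(masks) == len(bounding_boxes) == 1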
55 changes: 33 additions & 22 deletions sketch_map_tool/upload_processing/ml_models.py
@@ -2,16 +2,17 @@
from pathlib import Path

import neptune
import requests
import torch
from torch._prims_common import DeviceLikeType

from sketch_map_tool.config import get_config_value


def init_model(id: str) -> Path:
"""Initialize model. Download model to data dir if not present."""
# TODO: _check_id(id)
# TODO: check if model is valid/working
raw = Path(get_config_value("data-dir")) / id
path = raw.with_suffix(_get_file_suffix(id))
path = raw.with_suffix(".pt")
if not path.is_file():
logging.info(f"Downloading model {id} from neptune.ai to {path}.")
model = neptune.init_model_version(
@@ -21,27 +22,37 @@ def init_model(id: str) -> Path:
mode="read-only",
)
model["model"].download(str(path))
return path
return path


def _check_id(id: str):
# TODO:
project = neptune.init_project(
project=get_config_value("neptune_project"),
api_token=get_config_value("neptune_api_token"),
mode="read-only",
)

if not project.exists("models/" + id):
raise ValueError("Invalid model ID: " + id)


def _get_file_suffix(id: str) -> str:
suffixes = {"SAM": ".pth", "OSM": ".pt", "ESRI": ".pt", "CLR": ".pt"}
def init_sam2(id: str = "sam2_hiera_base_plus") -> Path:
raw = Path(get_config_value("data-dir")) / id
path = raw.with_suffix(".pt")
base_url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
url = base_url + id + ".pt"
if not path.is_file():
logging.info(f"Downloading model SAM-2 from fbaipublicfiles.com to {path}.")
response = requests.get(url=url)
with open(path, mode="wb") as file:
file.write(response.content)
return path

for key in suffixes:
if key in id:
return suffixes[key]

raise ValueError(f"Unexpected model ID: {id}")
def select_computation_device() -> DeviceLikeType:
"""Select computation device (cuda, mps, cpu) for SAM-2"""
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
logging.info(f"Using device: {device}")

if device.type == "cuda":
# use bfloat16 autocast for the whole process
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# turn on TensorFloat-32 (TF32) for Ampere GPUs
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
return device
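A small sketch of how the two new helpers are meant to be used together, mirroring the call sites in tasks.py and the test fixture below:

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from sketch_map_tool.upload_processing.ml_models import init_sam2, select_computation_device

checkpoint = init_sam2()              # downloads sam2_hiera_base_plus.pt on first call
device = select_computation_device()  # prefers cuda, then mps, falls back to cpu
model = build_sam2(config_file="sam2_hiera_b+.yaml", ckpt_path=checkpoint, device=device)
predictor = SAM2ImagePredictor(model)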

Two further large diffs are not rendered by default.

46 changes: 38 additions & 8 deletions tests/integration/upload_processing/test_detect_markings.py
@@ -1,25 +1,37 @@
import random

import numpy as np
import pytest
from PIL import Image, ImageEnhance
from segment_anything import SamPredictor, sam_model_registry
from PIL import Image, ImageDraw, ImageOps
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from ultralytics import YOLO
from ultralytics_4bands import YOLO as YOLO_4

from sketch_map_tool.config import get_config_value
from sketch_map_tool.upload_processing.detect_markings import (
detect_markings,
)
from sketch_map_tool.upload_processing.ml_models import init_model
from sketch_map_tool.upload_processing.ml_models import (
init_model,
init_sam2,
select_computation_device,
)


# Initialize ml-models.
# This usually happens inside the celery task: `digitize_sketches`
@pytest.fixture
def sam_predictor():
"""Zero shot segment anything model"""
sam_path = init_model(get_config_value("neptune_model_id_sam"))
sam_model = sam_model_registry[get_config_value("model_type_sam")](sam_path)
return SamPredictor(sam_model) # mask predictor
path = init_sam2()
device = select_computation_device()
sam2_model = build_sam2(
config_file="sam2_hiera_b+.yaml",
ckpt_path=path,
device=device,
)
return SAM2ImagePredictor(sam2_model)


@pytest.fixture
@@ -74,6 +86,24 @@ def test_detect_markings(
yolo_cls,
sam_predictor,
)

img = Image.fromarray(map_frame_marked)
for m in markings:
img = Image.fromarray(m)
ImageEnhance.Contrast(img).enhance(10).show()
colors = ["red", "green", "blue", "yellow", "purple", "orange", "pink", "brown"]
m[m == m.max()] = 255
colored_marking = ImageOps.colorize(
Image.fromarray(m).convert("L"), black="black", white=random.choice(colors)
)
img.paste(colored_marking, (0, 0), Image.fromarray(m))
# draw bbox around each marking, derived from the mask m
bbox = (
np.min(np.where(m)[1]),
np.min(np.where(m)[0]),
np.max(np.where(m)[1]),
np.max(np.where(m)[0]),
)

draw = ImageDraw.Draw(img)
draw.rectangle(bbox, outline="red", width=2)

img.show()
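The bounding-box derivation in the test above could be factored into a small helper; a sketch under the assumption that `m` is a 2-D mask array (`mask_to_bbox` is a hypothetical name, not part of this PR):

import numpy as np

def mask_to_bbox(mask: np.ndarray) -> tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the non-zero region of a 2-D mask."""
    ys, xs = np.nonzero(mask)
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())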