generated from autodistill/autodistill-base-model-template
-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathgrounded_sam_2.py
95 lines (76 loc) · 3.09 KB
/
grounded_sam_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from dataclasses import dataclass
# Both variables are set here, ahead of the `torch` / model-library imports
# below. NOTE(review): presumably so those libraries read the values when they
# initialize — confirm the ordering is required before moving these lines.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
# NOTE(review): looks like this turns off Hugging Face tokenizers parallelism
# (and the related fork warning) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import Any
import numpy as np
import supervision as sv
import torch
from autodistill.detection import CaptionOntology, DetectionBaseModel
from autodistill.helpers import load_image
from autodistill_florence_2 import Florence2
from autodistill_grounded_sam_2.helpers import (
combine_detections,
load_grounding_dino,
load_SAM,
)
# Current user's home directory (not referenced in the visible part of this file).
HOME = os.path.expanduser("~")
# CUDA device when torch reports CUDA as available, otherwise CPU
# (not referenced in the visible part of this file).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Created once at import time by load_SAM(); every GroundedSAM2 instance stores
# a reference to this same object as its `sam_2_predictor`.
SamPredictor = load_SAM()
# Values accepted for the `model` argument of GroundedSAM2.__init__.
SUPPORTED_GROUNDING_MODELS = ["Florence 2", "Grounding DINO"]
@dataclass
class GroundedSAM2(DetectionBaseModel):
    """Detect objects with a grounding model, then segment each box with SAM 2.

    The grounding model ("Florence 2" or "Grounding DINO") produces bounding
    boxes for the prompts in ``ontology``; the module-level SAM 2 predictor is
    then queried once per box and the resulting masks are attached to the
    returned detections.
    """

    # Prompt/class mapping used by the grounding model.
    ontology: CaptionOntology
    # Grounding DINO box-confidence threshold (mirrors
    # ``grounding_dino_box_threshold``).
    box_threshold: float
    # Grounding DINO text-confidence threshold (mirrors
    # ``grounding_dino_text_threshold``).
    text_threshold: float

    def __init__(
        self,
        ontology: CaptionOntology,
        model: str = "Florence 2",
        grounding_dino_box_threshold: float = 0.35,
        grounding_dino_text_threshold: float = 0.25,
    ) -> None:
        """Load the selected grounding model and bind the shared SAM 2 predictor.

        Args:
            ontology: Ontology whose prompts are sent to the grounding model.
            model: One of ``SUPPORTED_GROUNDING_MODELS``.
            grounding_dino_box_threshold: ``box_threshold`` passed to
                Grounding DINO. Only read when ``model`` is "Grounding DINO".
            grounding_dino_text_threshold: ``text_threshold`` passed to
                Grounding DINO. Only read when ``model`` is "Grounding DINO".

        Raises:
            ValueError: If ``model`` is not in ``SUPPORTED_GROUNDING_MODELS``.
        """
        if model not in SUPPORTED_GROUNDING_MODELS:
            raise ValueError(
                f"Grounding model {model} is not supported. Supported models are {SUPPORTED_GROUNDING_MODELS}"
            )

        self.ontology = ontology
        self.model = model

        # Only the selected grounding back-end is loaded.
        if model == "Florence 2":
            self.florence_2_predictor = Florence2(ontology=ontology)
        elif model == "Grounding DINO":
            self.grounding_dino_model = load_grounding_dino()

        self.sam_2_predictor = SamPredictor

        self.grounding_dino_box_threshold = grounding_dino_box_threshold
        self.grounding_dino_text_threshold = grounding_dino_text_threshold
        # The dataclass-generated __repr__/__eq__ read every declared field;
        # without these two assignments they raise AttributeError.
        self.box_threshold = grounding_dino_box_threshold
        self.text_threshold = grounding_dino_text_threshold

    def _detect_with_grounding_dino(self, image: Any) -> sv.Detections:
        """Run Grounding DINO once per ontology prompt and merge the results."""
        detections_list = [
            self.grounding_dino_model.predict_with_classes(
                image=image,
                classes=[description],
                box_threshold=self.grounding_dino_box_threshold,
                text_threshold=self.grounding_dino_text_threshold,
            )
            for description in self.ontology.prompts()
        ]
        # Each prompt is queried on its own, so the position of a result in
        # the list is used as its class id.
        return combine_detections(
            detections_list, overwrite_class_ids=range(len(detections_list))
        )

    def predict(self, input: Any) -> sv.Detections:
        """Detect objects in ``input`` and attach one SAM 2 mask per box.

        Args:
            input: Anything accepted by ``autodistill.helpers.load_image``.

        Returns:
            The grounding model's detections with ``mask`` set to a boolean
            array holding one mask per detected box. When nothing is detected
            the mask array is empty with shape ``(0, height, width)``.

        Raises:
            ValueError: If ``self.model`` is not a supported grounding model
                (only possible if it was changed after construction).
        """
        image = load_image(input, return_format="cv2")

        if self.model == "Florence 2":
            detections = self.florence_2_predictor.predict(image)
        elif self.model == "Grounding DINO":
            detections = self._detect_with_grounding_dino(image)
        else:
            raise ValueError(
                f"Grounding model {self.model} is not supported. Supported models are {SUPPORTED_GROUNDING_MODELS}"
            )

        if len(detections.xyxy) == 0:
            # Nothing to segment: skip the SAM 2 image embedding and keep the
            # mask array three-dimensional instead of np.array([]).
            # Assumes the loaded image is array-like with height and width as
            # its first two axes — TODO confirm against load_image.
            height, width = np.asarray(image).shape[:2]
            detections.mask = np.zeros((0, height, width), dtype=bool)
            return detections

        # NOTE(review): `image` is requested in "cv2" format; confirm the
        # channel order is what the SAM 2 predictor's set_image expects.
        # Autocast is only requested when CUDA exists, which avoids the
        # "CUDA is not available. Disabling" warning on CPU-only hosts.
        with torch.inference_mode(), torch.autocast(
            "cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
        ):
            self.sam_2_predictor.set_image(image)
            result_masks = []
            for box in detections.xyxy:
                masks, scores, _ = self.sam_2_predictor.predict(
                    box=box, multimask_output=False
                )
                # Keep the highest-scoring candidate mask for this box.
                best = np.argmax(scores)
                result_masks.append(masks[best].astype(bool))

        detections.mask = np.array(result_masks)
        return detections