forked from CASIA-IVA-Lab/FastSAM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
159 lines (141 loc) · 5.13 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
# Thanks to chenxwh.
import argparse
import cv2
import shutil
import ast
from cog import BasePredictor, Input, Path
from ultralytics import YOLO
from utils.tools import *
class Predictor(BasePredictor):
    """Cog predictor wrapping the FastSAM-s / FastSAM-x segmentation models."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        # Both checkpoints are loaded once so ``predict`` can pick either one
        # by name without reloading weights on every request.
        self.models = {k: YOLO(f"{k}.pt") for k in ["FastSAM-s", "FastSAM-x"]}

    @staticmethod
    def _select_device():
        """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def _parse_literal(value, name):
        """Parse ``value`` with ``ast.literal_eval``, re-raising failures as ValueError naming ``name``."""
        try:
            # literal_eval only accepts Python literals, so user input cannot
            # execute code here (unlike eval).
            return ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError) as err:
            raise ValueError(
                f"{name} must be a Python literal such as a list, got {value!r}"
            ) from err

    def predict(
        self,
        input_image: Path = Input(description="Input image"),
        model_name: str = Input(
            description="choose a model",
            choices=["FastSAM-x", "FastSAM-s"],
            default="FastSAM-x",
        ),
        iou: float = Input(
            description="iou threshold for filtering the annotations", default=0.7
        ),
        text_prompt: str = Input(
            description='use text prompt eg: "a black dog"', default=None
        ),
        conf: float = Input(description="object confidence threshold", default=0.25),
        retina: bool = Input(
            description="draw high-resolution segmentation masks", default=True
        ),
        box_prompt: str = Input(default="[0,0,0,0]", description="[x,y,w,h]"),
        point_prompt: str = Input(default="[[0,0]]", description="[[x1,y1],[x2,y2]]"),
        point_label: str = Input(default="[0]", description="[1,0] 0:background, 1:foreground"),
        withContours: bool = Input(
            description="draw the edges of the masks", default=False
        ),
        better_quality: bool = Input(
            description="better quality using morphologyEx", default=False
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Prompts are applied with this precedence:

        1. box prompt, when its width and height are both non-zero;
        2. text prompt, when one is given;
        3. point prompt, when the first point is not ``[0, 0]``;
        4. otherwise every mask the model produced is drawn.

        ``fast_process`` renders the result into the ``output`` directory
        (recreated on every call). The first file found there is copied to
        ``/tmp/out.png``, and that path is returned.

        Raises:
            ValueError: a prompt string is not a valid literal, the box prompt
                does not have four values, or the point prompt is empty.
            RuntimeError: the model produced no masks, or nothing was rendered
                into the output directory.
        """
        # Start from an empty output directory so the file picked below is
        # the one rendered by this call.
        out_path = "output"
        if os.path.exists(out_path):
            shutil.rmtree(out_path)
        os.makedirs(out_path, exist_ok=True)

        box = self._parse_literal(box_prompt, "box_prompt")
        points = self._parse_literal(point_prompt, "point_prompt")
        labels = self._parse_literal(point_label, "point_label")
        # The branch selection below indexes box[2], box[3] and points[0];
        # validate here so bad input fails with a clear message.
        if not isinstance(box, (list, tuple)) or len(box) != 4:
            raise ValueError(f"box_prompt must be [x,y,w,h], got {box_prompt!r}")
        if not isinstance(points, (list, tuple)) or not points:
            raise ValueError(
                f"point_prompt must be a non-empty list of [x,y] pairs, got {point_prompt!r}"
            )

        args = argparse.Namespace(
            better_quality=better_quality,
            box_prompt=box,
            conf=conf,
            device=self._select_device(),
            img_path=str(input_image),
            imgsz=1024,
            iou=iou,
            # Record the checkpoint that is actually used (was always FastSAM-x).
            model_path=f"{model_name}.pt",
            output=out_path,
            point_label=labels,
            point_prompt=points,
            randomcolor=True,
            retina=retina,
            text_prompt=text_prompt,
            withContours=withContours,
        )

        model = self.models[model_name]
        results = model(
            str(input_image),
            imgsz=args.imgsz,
            device=args.device,
            retina_masks=args.retina,
            iou=args.iou,
            conf=args.conf,
            max_det=100,
        )
        # Every branch below reads results[0].masks. NOTE(review): masks is
        # presumably None when nothing was segmented; fail clearly instead of
        # with an AttributeError.
        if not results or results[0].masks is None:
            raise RuntimeError("FastSAM did not produce any masks for the input image")

        if args.box_prompt[2] != 0 and args.box_prompt[3] != 0:
            annotations = np.array([prompt(results, args, box=True)])
            fast_process(
                annotations=annotations,
                args=args,
                mask_random_color=args.randomcolor,
                bbox=convert_box_xywh_to_xyxy(args.box_prompt),
            )
        elif args.text_prompt is not None:
            formatted = format_results(results[0], 0)
            annotations = np.array([prompt(formatted, args, text=True)])
            fast_process(
                annotations=annotations, args=args, mask_random_color=args.randomcolor
            )
        elif args.point_prompt[0] != [0, 0]:
            formatted = format_results(results[0], 0)
            # Wrap the single selected mask so it has the same list-of-masks
            # layout as the other branches.
            annotations = np.array([prompt(formatted, args, point=True)])
            fast_process(
                annotations=annotations,
                args=args,
                mask_random_color=args.randomcolor,
                points=args.point_prompt,
            )
        else:
            fast_process(
                annotations=results[0].masks.data,
                args=args,
                mask_random_color=args.randomcolor,
            )

        produced = sorted(os.listdir(out_path))
        if not produced:
            raise RuntimeError(f"fast_process wrote no output into {out_path!r}")
        # Previously "/tmp.out.png" (a file in the filesystem root) — a typo
        # for a file inside /tmp.
        out = "/tmp/out.png"
        shutil.copy(os.path.join(out_path, produced[0]), out)
        return Path(out)
def prompt(results, args, box=None, point=None, text=None):
    """Select the mask matching the requested prompt.

    The flags are checked in the order ``box``, ``point``, ``text``; the first
    truthy one decides which helper is used.

    Args:
        results: For ``box``, the raw model output (``results[0].masks.data``
            is read). For ``point`` and ``text``, the list returned by
            ``format_results``, as passed by ``Predictor.predict``.
        args: Namespace providing ``img_path`` and the parsed prompt values
            (``box_prompt``; ``point_prompt`` and ``point_label``;
            ``text_prompt`` and ``device``).
        box, point, text: Flags selecting which prompt to apply.

    Returns:
        The mask returned by ``box_prompt``, ``point_prompt`` or
        ``text_prompt``, or ``None`` when no flag is set.

    Raises:
        OSError: ``args.img_path`` cannot be read by OpenCV. This is only
            checked for box/point prompts, which need the original image size.
    """
    if box or point:
        # Only the box and point helpers need the original image dimensions.
        ori_img = cv2.imread(args.img_path)
        if ori_img is None:  # cv2.imread signals failure by returning None
            raise OSError(f"cv2.imread could not read image {args.img_path!r}")
        ori_h, ori_w = ori_img.shape[:2]
        if box:
            mask, _ = box_prompt(
                results[0].masks.data,
                convert_box_xywh_to_xyxy(args.box_prompt),
                ori_h,
                ori_w,
            )
        else:
            mask, _ = point_prompt(
                results, args.point_prompt, args.point_label, ori_h, ori_w
            )
    elif text:
        mask, _ = text_prompt(results, args.text_prompt, args.img_path, args.device)
    else:
        return None
    return mask