Skip to content

Commit 4f1fc4c

Browse files
authored
SAM2 AMG example mIoU, perf numbers and more SAM2 model annotations (#1196)
1 parent 186e578 commit 4f1fc4c

File tree

7 files changed

+345
-83
lines changed

7 files changed

+345
-83
lines changed

examples/sam2_amg_server/README.md

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,101 @@
22
```
33
curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output path/to/output.png
44
```
5+
6+
## Example script to collect rles
7+
8+
Start the server
9+
10+
```
11+
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
12+
```
13+
14+
Collect the rles
15+
16+
```
17+
xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < image_paths > rle_masks
18+
```
19+
20+
## mIoU scores on random subset of sav validation dataset
21+
22+
Experiments were run on an H100 with batch size 1
23+
24+
| mode | mIoU | mask count mismatch | avg. ms per request |
25+
| --- |--- | ------------------ | ----------------- |
26+
| baseline | 1.0 | 0 | 786 |
27+
| ao | 1.0 | 0 | 738 |
28+
| fast | 0.95 | 190 | 563 |
29+
| furious | 0 | 1000 | 204 |
30+
31+
mask count mismatch counts the number of requests where the number of masks differ from the baseline.
32+
For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
33+
We exclude these examples from the mIoU calculation.
34+
35+
### 1. Create a random subset of 1000 images
36+
```
37+
find sav_val -type f > sav_val_image_paths
38+
shuf -n 1000 sav_val_image_paths > sav_val_image_paths_shuf_1000
39+
```
40+
41+
### 2. Use the baseline (https://github.com/facebookresearch/sam2) to generate rles
42+
43+
Make sure you've installed https://github.com/facebookresearch/sam2
44+
45+
Start server
46+
```
47+
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --baseline
48+
```
49+
50+
Generate and save rles (one line per json via `-w "\n"`)
51+
```
52+
$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_baseline_shuf_1000
53+
54+
real 13m6.374s
55+
user 0m3.349s
56+
sys 0m4.137s
57+
```
58+
59+
### 3. Start server with torchao variant of SAM2
60+
Start server
61+
```
62+
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname>
63+
```
64+
65+
Generate and save rles (one line per json via `-w "\n"`)
66+
```
67+
$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_shuf_1000
68+
69+
real 12m18.916s
70+
user 0m3.506s
71+
sys 0m4.350s
72+
```
73+
74+
### 4. Start server with torchao variant of SAM2 and `--fast` optimizations
75+
Start server
76+
```
77+
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
78+
```
79+
80+
Generate and save rles (one line per json via `-w "\n"`)
81+
```
82+
$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_fast_shuf_1000
83+
84+
real 9m23.912s
85+
user 0m3.271s
86+
sys 0m4.138s
87+
```
88+
89+
### 5. Start server with torchao variant of SAM2 and `--fast` and `--furious` optimizations
90+
Start server
91+
```
92+
python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast --furious
93+
```
94+
95+
Generate and save rles (one line per json via `-w "\n"`)
96+
```
97+
$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_fast_furious_shuf_1000
98+
99+
real 3m24.383s
100+
user 0m3.583s
101+
sys 0m4.519s
102+
```
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import fire
2+
import torch
3+
import json
4+
from sam2.utils.amg import rle_to_mask
5+
6+
"""
7+
Script to calculate mIoU given two lists of rles from upload_rle endpoint
8+
of server.
9+
"""
10+
11+
12+
def iou(mask1, mask2):
    """Compute the intersection-over-union of two 2D boolean masks.

    Args:
        mask1: 2D tensor (H, W) of bools (or 0/1 values).
        mask2: 2D tensor of the same shape.

    Returns:
        0-dim float tensor in [0, 1]. Two empty masks are defined to have
        IoU 1.0; the original expression produced NaN via 0/0, which would
        silently poison the running mIoU sum in `main`.
    """
    assert mask1.dim() == 2
    assert mask2.dim() == 2
    intersection = torch.logical_and(mask1, mask2)
    union = torch.logical_or(mask1, mask2)
    union_sum = union.sum(dim=(-1, -2))
    # Guard the 0/0 case: two identical empty masks are a perfect match.
    if union_sum == 0:
        return torch.tensor(1.0)
    return intersection.sum(dim=(-1, -2)) / union_sum
18+
19+
20+
def main(path0, path1):
    """Compare two files of per-image RLE masks and report mIoU.

    Each input file is expected to contain one JSON object per line (as
    produced by the server's /upload_rle endpoint), mapping mask names to
    RLE-encoded masks. Lines are compared pairwise.

    Args:
        path0: path to the reference results file.
        path1: path to the candidate results file.

    Prints the number of lines whose mask-name sets differ (these are
    excluded) and the mean IoU over all compared masks.
    """
    fail_count = 0
    miou_sum = 0.0
    miou_count = 0
    with open(path0, 'r') as f0, open(path1, 'r') as f1:
        for line0, line1 in zip(f0, f1):
            masks0 = json.loads(line0)
            masks1 = json.loads(line1)
            # A differing set of mask names means the two runs segmented the
            # image differently; exclude the image from the mIoU calculation.
            if masks0.keys() != masks1.keys():
                fail_count += 1
                continue
            for mask0, mask1 in zip(masks0.values(), masks1.values()):
                mask0 = torch.from_numpy(rle_to_mask(mask0))
                mask1 = torch.from_numpy(rle_to_mask(mask1))
                miou_sum += iou(mask0, mask1).item()
                miou_count += 1

    # Guard against ZeroDivisionError when every line was a mismatch
    # (e.g. the furious variant, which mismatched on all 1000 requests).
    if miou_count == 0:
        print(f"fail_count: {fail_count} mIoU: n/a (no comparable masks)")
    else:
        print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}")


if __name__ == "__main__":
    fire.Fire(main)

examples/sam2_amg_server/dog_rle.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

examples/sam2_amg_server/server.py

Lines changed: 104 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@
3030

3131
# torch.set_float32_matmul_precision('high')
3232

33+
34+
def iou(mask1, mask2):
    """Return the intersection-over-union of two 2D masks as a 0-dim tensor."""
    assert mask1.dim() == 2
    assert mask2.dim() == 2
    overlap = torch.logical_and(mask1, mask2).sum(dim=(-1, -2))
    combined = torch.logical_or(mask1, mask2).sum(dim=(-1, -2))
    return overlap / combined
40+
41+
3342
def show_anns(anns):
3443
if len(anns) == 0:
3544
return
@@ -49,17 +58,44 @@ def show_anns(anns):
4958
return torch.stack(ms)
5059

5160

52-
def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=False, points_per_batch=64):
61+
def profiler_runner(path, fn, *args, **kwargs):
    """Run fn(*args, **kwargs) under the torch profiler; dump a Chrome trace to path.

    Returns whatever fn returns.
    """
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    # Trace is viewable at chrome://tracing or https://ui.perfetto.dev
    prof.export_chrome_trace(path)
    return result
69+
70+
71+
def main(checkpoint_path,
72+
baseline=False,
73+
fast=False,
74+
furious=False,
75+
unittest=False,
76+
benchmark=False,
77+
profile=None,
78+
verbose=False,
79+
points_per_batch=64,
80+
port=5000,
81+
host="127.0.0.1",
82+
dry=False):
5383
if verbose:
54-
logging.basicConfig(level=logging.INFO)
84+
logging.basicConfig(level=logging.INFO,
85+
format='%(asctime)s - %(levelname)s - %(message)s',
86+
datefmt='%Y-%m-%d %H:%M:%S')
5587
logging.info(f"Running with fast set to {fast} and furious set to {furious}")
88+
logging.info(f"Running with port {port} and host {host}")
5689

57-
if fast:
58-
from torchao._models.sam2.build_sam import build_sam2
59-
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
60-
else:
90+
if baseline:
91+
logging.info(f"Importing sam2 from outside of torchao. If this errors, install https://github.com/facebookresearch/sam2")
6192
from sam2.build_sam import build_sam2
6293
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
94+
from sam2.utils.amg import rle_to_mask
95+
else:
96+
from torchao._models.sam2.build_sam import build_sam2
97+
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
98+
from torchao._models.sam2.utils.amg import rle_to_mask
6399

64100
device = "cuda"
65101
from pathlib import Path
@@ -70,7 +106,7 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
70106
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
71107

72108
logging.info(f"Using {points_per_batch} points_per_batch")
73-
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch)
109+
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
74110

75111
if furious:
76112
torch.set_float32_matmul_precision('high')
@@ -107,6 +143,37 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
107143
logging.info(f"Running one iteration to compile.")
108144
masks = mask_generator.generate(example_image)
109145
logging.info(f"First iteration took {time.time() - t}s.")
146+
if unittest:
147+
logging.info(f"Running strict comparison to reference mask")
148+
import json
149+
ref_masks = json.loads(open("dog_rle.json").read())
150+
ret_data = {}
151+
for mask_id in range(len(masks)):
152+
ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"]
153+
v0_areas = []
154+
v1_areas = []
155+
miou_sum = 0.0
156+
miou_count = 0
157+
for k0 in ref_masks:
158+
assert k0 in ret_data, f"Expected {k0} to be in return data"
159+
from torchao._models.sam2.utils.amg import area_from_rle
160+
v0_area = area_from_rle(ref_masks[k0])
161+
v1_area = area_from_rle(ret_data[k0])
162+
v0_areas.append(v0_area)
163+
v1_areas.append(v1_area)
164+
if v0_area != v1_area:
165+
print(f"v0 area {v0_area} doesn't match v1 area {v1_area}")
166+
v0_mask = torch.from_numpy(rle_to_mask(ref_masks[k0]))
167+
v1_mask = torch.from_numpy(rle_to_mask(ret_data[k0]))
168+
if not torch.allclose(v0_mask, v1_mask):
169+
miou_sum += iou(v0_mask, v1_mask)
170+
miou_count += 1
171+
print(f"Masks don't match for key {k0}. IoU is {iou(v0_mask, v1_mask)}")
172+
if miou_count == 0:
173+
print("Masks exactly match reference.")
174+
else:
175+
print(f"mIoU is {miou_sum / miou_count}")
176+
110177
if benchmark:
111178
logging.info(f"Running 3 warmup iterations.")
112179
for _ in range(3):
@@ -121,7 +188,13 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
121188
max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
122189
max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
123190
print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%")
124-
return
191+
192+
if profile is not None:
193+
print(f"Saving profile under {profile}")
194+
profiler_runner(profile, mask_generator.generate, example_image)
195+
196+
if dry:
197+
return
125198

126199
app = FastAPI()
127200

@@ -133,6 +206,25 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
133206
allow_methods=["*"],
134207
allow_headers=["*"],
135208
)
209+
210+
@app.post("/upload_rle")
async def upload_rle(image: UploadFile = File(...)):
    """Segment an uploaded image and return its masks keyed by mask name."""
    # Persist the upload to disk so cv2 can read it back by path.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{image.filename}")
    with open(tmp.name, "wb") as out:
        shutil.copyfileobj(image.file, out)

    # Reload and convert from cv2's BGR channel order to RGB.
    example_image = cv2.cvtColor(cv2.imread(tmp.name), cv2.COLOR_BGR2RGB)

    t = time.time()
    with torch.backends.cuda.sdp_kernel(enable_cudnn=True):
        masks = mask_generator.generate(example_image)
    print(f"Took {time.time() - t} to generate a mask for input image.")

    return {f"mask_{mask_id}": m["segmentation"] for mask_id, m in enumerate(masks)}
136228

137229
@app.post("/upload")
138230
async def upload_image(image: UploadFile = File(...)):
@@ -143,13 +235,16 @@ async def upload_image(image: UploadFile = File(...)):
143235

144236
# Read the image back into memory to send as response
145237
example_image = cv2.imread(temp_file.name)
238+
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
146239
t = time.time()
147240
with torch.backends.cuda.sdp_kernel(enable_cudnn=True):
148241
masks = mask_generator.generate(example_image)
149242
print(f"Took {time.time() - t} to generate a mask for input image.")
150243
# Save an example
151244
plt.figure(figsize=(example_image.shape[1]/100., example_image.shape[0]/100.), dpi=100)
152245
plt.imshow(example_image)
246+
for i in range(len(masks)):
247+
masks[i]["segmentation"] = rle_to_mask(masks[i]["segmentation"])
153248
show_anns(masks)
154249
plt.axis('off')
155250
plt.tight_layout()
@@ -163,7 +258,7 @@ async def upload_image(image: UploadFile = File(...)):
163258
return StreamingResponse(BytesIO(image_data), media_type="image/png")
164259

165260

166-
uvicorn.run(app, host="127.0.0.1", port=5000, log_level="info")
261+
uvicorn.run(app, host=host, port=port, log_level="info")
167262

168263
if __name__ == "__main__":
169264
fire.Fire(main)

0 commit comments

Comments
 (0)