Skip to content

Commit

Permalink
Copy SAM benchmark script for now
Browse files Browse the repository at this point in the history
  • Loading branch information
tokk-nv committed May 20, 2024
1 parent 51dbca8 commit 0df27a5
Showing 1 changed file with 144 additions and 0 deletions.
144 changes: 144 additions & 0 deletions packages/audio/whisper_streaming/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
import os
import time
import datetime
import resource
import argparse
import socket
from urllib.parse import urlparse

import numpy as np
import matplotlib.pyplot as plt
import cv2
import PIL.Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth")
parser.add_argument('-i', '--images', action='append', nargs='*', help="Paths to images to test")

parser.add_argument('-r', '--runs', type=int, default=2, help="Number of inferencing runs to do (for timing)")
parser.add_argument('-w', '--warmup', type=int, default=1, help='the number of warmup iterations')

parser.add_argument('-s', '--save', type=str, default='', help='CSV file to save benchmarking results to')

args = parser.parse_args()

if not args.images:
args.images = [
"https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg",
"https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg",
"https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg",
]
else:
args.images = [x[0] for x in args.images]

print(args)

import requests
from tqdm import tqdm

def download_from_url(url, filename=None):

if filename is None:
filename = os.path.basename(urlparse(url).path)

if not os.path.isfile(filename):

response = requests.get(url, stream=True)
total_size_in_bytes= int(response.headers.get('content-length', 0))
block_size = 1024 # 1Kibibyte

print(f"Downloading {filename} :")
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

with open(filename, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)

progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
print("ERROR, download failed!")

return os.path.abspath(filename)

def get_max_rss(): # peak memory usage in MB (max RSS - https://stackoverflow.com/a/7669482)
return (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss) / 1024

def save_anns(cv2_image, anns):

plt.imshow(cv2_image)

if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)

img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
img[:,:,3] = 0
for ann in sorted_anns:
m = ann['segmentation']
color_mask = np.concatenate([np.random.random(3), [0.35]])
img[m] = color_mask
plt.imshow(img)
plt.axis('off')
plt.savefig("sam_benchmark_output.jpg")

avg_encoder=0
avg_latency=0
cv2_image=None
mask=None

CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
FILENAME = os.path.basename(urlparse(args.checkpoint).path)
download_from_url(args.checkpoint, FILENAME)

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

imagepaths = []
for imageurl in args.images:
imagepaths.append(download_from_url(imageurl))

for run in range(args.runs + args.warmup):

for imagepath in imagepaths:

cv2_image = cv2.imread(imagepath)
cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)

time_begin=time.perf_counter()
masks = mask_generator.generate(cv2_image)
time_elapsed=time.perf_counter() - time_begin

print(f"{imagepath}")
print(f" Full pipeline : {time_elapsed:.3f} seconds")

if run >= args.warmup:
avg_latency += time_elapsed

avg_latency /= ( args.runs * len(args.images) )

memory_usage=get_max_rss()

print(f"AVERAGE of {args.runs} runs:")
print(f" latency --- {avg_latency:.3f} sec")
print(f"Memory consumption : {memory_usage:.2f} MB")

save_anns(cv2_image, masks)

if args.save:
if not os.path.isfile(args.save): # csv header
with open(args.save, 'w') as file:
file.write(f"timestamp, hostname, api, checkpoint, latency, memory\n")
with open(args.save, 'a') as file:
file.write(f"{datetime.datetime.now().strftime('%Y%m%d %H:%M:%S')}, {socket.gethostname()}, ")
file.write(f"sam-python, {args.checkpoint}, {avg_latency}, {memory_usage}\n")

0 comments on commit 0df27a5

Please sign in to comment.