3030
3131# torch.set_float32_matmul_precision('high')
3232
33+
34+ def iou (mask1 , mask2 ):
35+ assert mask1 .dim () == 2
36+ assert mask2 .dim () == 2
37+ intersection = torch .logical_and (mask1 , mask2 )
38+ union = torch .logical_or (mask1 , mask2 )
39+ return (intersection .sum (dim = (- 1 , - 2 )) / union .sum (dim = (- 1 , - 2 )))
40+
41+
3342def show_anns (anns ):
3443 if len (anns ) == 0 :
3544 return
@@ -49,17 +58,44 @@ def show_anns(anns):
4958 return torch .stack (ms )
5059
5160
52- def main (checkpoint_path , fast = False , furious = False , benchmark = False , verbose = False , points_per_batch = 64 ):
def profiler_runner(path, fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` under the torch profiler.

    Records CPU and CUDA activity (with input shapes), writes a Chrome
    trace to ``path`` once profiling ends, and returns whatever ``fn``
    returned.
    """
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    with torch.profiler.profile(activities=activities,
                                record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    # Export after the context exits so the trace is complete.
    prof.export_chrome_trace(path)
    return result
69+
70+
71+ def main (checkpoint_path ,
72+ baseline = False ,
73+ fast = False ,
74+ furious = False ,
75+ unittest = False ,
76+ benchmark = False ,
77+ profile = None ,
78+ verbose = False ,
79+ points_per_batch = 64 ,
80+ port = 5000 ,
81+ host = "127.0.0.1" ,
82+ dry = False ):
5383 if verbose :
54- logging .basicConfig (level = logging .INFO )
84+ logging .basicConfig (level = logging .INFO ,
85+ format = '%(asctime)s - %(levelname)s - %(message)s' ,
86+ datefmt = '%Y-%m-%d %H:%M:%S' )
5587 logging .info (f"Running with fast set to { fast } and furious set to { furious } " )
88+ logging .info (f"Running with port { port } and host { host } " )
5689
57- if fast :
58- from torchao ._models .sam2 .build_sam import build_sam2
59- from torchao ._models .sam2 .automatic_mask_generator import SAM2AutomaticMaskGenerator
60- else :
90+ if baseline :
91+ logging .info (f"Importing sam2 from outside of torchao. If this errors, install https://github.com/facebookresearch/sam2" )
6192 from sam2 .build_sam import build_sam2
6293 from sam2 .automatic_mask_generator import SAM2AutomaticMaskGenerator
94+ from sam2 .utils .amg import rle_to_mask
95+ else :
96+ from torchao ._models .sam2 .build_sam import build_sam2
97+ from torchao ._models .sam2 .automatic_mask_generator import SAM2AutomaticMaskGenerator
98+ from torchao ._models .sam2 .utils .amg import rle_to_mask
6399
64100 device = "cuda"
65101 from pathlib import Path
@@ -70,7 +106,7 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
70106 sam2 = build_sam2 (model_cfg , sam2_checkpoint , device = device , apply_postprocessing = False )
71107
72108 logging .info (f"Using { points_per_batch } points_per_batch" )
73- mask_generator = SAM2AutomaticMaskGenerator (sam2 , points_per_batch = points_per_batch )
109+ mask_generator = SAM2AutomaticMaskGenerator (sam2 , points_per_batch = points_per_batch , output_mode = "uncompressed_rle" )
74110
75111 if furious :
76112 torch .set_float32_matmul_precision ('high' )
@@ -107,6 +143,37 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
107143 logging .info (f"Running one iteration to compile." )
108144 masks = mask_generator .generate (example_image )
109145 logging .info (f"First iteration took { time .time () - t } s." )
146+ if unittest :
147+ logging .info (f"Running strict comparison to reference mask" )
148+ import json
149+ ref_masks = json .loads (open ("dog_rle.json" ).read ())
150+ ret_data = {}
151+ for mask_id in range (len (masks )):
152+ ret_data [f"mask_{ mask_id } " ] = masks [mask_id ]["segmentation" ]
153+ v0_areas = []
154+ v1_areas = []
155+ miou_sum = 0.0
156+ miou_count = 0
157+ for k0 in ref_masks :
158+ assert k0 in ret_data , f"Expected { k0 } to be in return data"
159+ from torchao ._models .sam2 .utils .amg import area_from_rle
160+ v0_area = area_from_rle (ref_masks [k0 ])
161+ v1_area = area_from_rle (ret_data [k0 ])
162+ v0_areas .append (v0_area )
163+ v1_areas .append (v1_area )
164+ if v0_area != v1_area :
165+ print (f"v0 area { v0_area } doesn't match v1 area { v1_area } " )
166+ v0_mask = torch .from_numpy (rle_to_mask (ref_masks [k0 ]))
167+ v1_mask = torch .from_numpy (rle_to_mask (ret_data [k0 ]))
168+ if not torch .allclose (v0_mask , v1_mask ):
169+ miou_sum += iou (v0_mask , v1_mask )
170+ miou_count += 1
171+ print (f"Masks don't match for key { k0 } . IoU is { iou (v0_mask , v1_mask )} " )
172+ if miou_count == 0 :
173+ print ("Masks exactly match reference." )
174+ else :
175+ print (f"mIoU is { miou_sum / miou_count } " )
176+
110177 if benchmark :
111178 logging .info (f"Running 3 warumup iterations." )
112179 for _ in range (3 ):
@@ -121,7 +188,13 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
121188 max_memory_allocated_percentage = int (100 * (max_memory_allocated_bytes / total_memory ))
122189 max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
123190 print (f"max_memory_allocated_bytes: { max_memory_allocated_bytes } MiB or { max_memory_allocated_percentage } %" )
124- return
191+
192+ if profile is not None :
193+ print (f"Saving profile under { profile } " )
194+ profiler_runner (profile , mask_generator .generate , example_image )
195+
196+ if dry :
197+ return
125198
126199 app = FastAPI ()
127200
@@ -133,6 +206,25 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa
133206 allow_methods = ["*" ],
134207 allow_headers = ["*" ],
135208 )
209+
    @app.post("/upload_rle")
    async def upload_rle(image: UploadFile = File(...)):
        """Segment an uploaded image and return its masks as RLE data.

        The response is a JSON object mapping "mask_0", "mask_1", ... to
        each mask's "segmentation" entry as produced by the mask
        generator (presumably uncompressed RLE, given the generator's
        output_mode — confirm against its configuration).
        """
        # Save the uploaded image to a temporary location so cv2 can read it.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{image.filename}")
        with open(temp_file.name, "wb") as b:
            shutil.copyfileobj(image.file, b)

        # Read the image back and convert BGR (cv2 default) to RGB,
        # which is what the mask generator expects.
        example_image = cv2.imread(temp_file.name)
        example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
        t = time.time()
        # Force the cuDNN SDPA backend while generating masks.
        with torch.backends.cuda.sdp_kernel(enable_cudnn=True):
            masks = mask_generator.generate(example_image)
        print(f"Took {time.time() - t} to generate a mask for input image.")
        # Key each mask's segmentation by its index for the JSON response.
        ret_data = {}
        for mask_id in range(len(masks)):
            ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"]
        return ret_data
136228
137229 @app .post ("/upload" )
138230 async def upload_image (image : UploadFile = File (...)):
@@ -143,13 +235,16 @@ async def upload_image(image: UploadFile = File(...)):
143235
144236 # Read the image back into memory to send as response
145237 example_image = cv2 .imread (temp_file .name )
238+ example_image = cv2 .cvtColor (example_image , cv2 .COLOR_BGR2RGB )
146239 t = time .time ()
147240 with torch .backends .cuda .sdp_kernel (enable_cudnn = True ):
148241 masks = mask_generator .generate (example_image )
149242 print (f"Took { time .time () - t } to generate a mask for input image." )
150243 # Save an example
151244 plt .figure (figsize = (example_image .shape [1 ]/ 100. , example_image .shape [0 ]/ 100. ), dpi = 100 )
152245 plt .imshow (example_image )
246+ for i in range (len (masks )):
247+ masks [i ]["segmentation" ] = rle_to_mask (masks [i ]["segmentation" ])
153248 show_anns (masks )
154249 plt .axis ('off' )
155250 plt .tight_layout ()
@@ -163,7 +258,7 @@ async def upload_image(image: UploadFile = File(...)):
163258 return StreamingResponse (BytesIO (image_data ), media_type = "image/png" )
164259
165260
166- uvicorn .run (app , host = "127.0.0.1" , port = 5000 , log_level = "info" )
261+ uvicorn .run (app , host = host , port = port , log_level = "info" )
167262
168263if __name__ == "__main__" :
169264 fire .Fire (main )
0 commit comments