From 4f1fc4c2d5eace4153c6e62123f5584e772fff4c Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Tue, 29 Oct 2024 19:42:36 -0700 Subject: [PATCH] SAM2 AMG example mIoU, perf numbers and more SAM2 model annotations (#1196) --- examples/sam2_amg_server/README.md | 98 ++++++++++++++ examples/sam2_amg_server/compare_rle_lists.py | 41 ++++++ examples/sam2_amg_server/dog_rle.json | 1 + examples/sam2_amg_server/server.py | 113 ++++++++++++++-- .../_models/sam2/automatic_mask_generator.py | 128 ++++++++++-------- .../_models/sam2/modeling/sam/mask_decoder.py | 5 +- torchao/_models/sam2/sam2_image_predictor.py | 42 +++--- 7 files changed, 345 insertions(+), 83 deletions(-) create mode 100644 examples/sam2_amg_server/compare_rle_lists.py create mode 100644 examples/sam2_amg_server/dog_rle.json diff --git a/examples/sam2_amg_server/README.md b/examples/sam2_amg_server/README.md index ecbd1bd37..07314a5d7 100644 --- a/examples/sam2_amg_server/README.md +++ b/examples/sam2_amg_server/README.md @@ -2,3 +2,101 @@ ``` curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output path/to/output.png ``` + +## Example script to collect rles + +Start the server + +``` +python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast +``` + +Collect the rles + +``` +xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < image_paths > rle_masks +``` + +## mIoU scores on random subset of sav validation dataset + +Experiments run on H100 and with batch size 1 + +| mode | mIoU | mask count mismatch | avg. ms per request | +| --- |--- | ------------------ | ----------------- | +| baseline | 1.0 | 0 | 786 | +| ao | 1.0 | 0 | 738 | +| fast | 0.95 | 190 | 563 | +| furious | 0 | 1000 | 204 | + +mask count mismatch counts the number of requests where the number of masks differs from the baseline. +For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19. +We exclude these examples from the mIoU calculation. 
+ +### 1. Create a random subset of 1000 images +``` +find sav_val -type f > sav_val_image_paths +shuf -n 1000 sav_val_image_paths > sav_val_image_paths_shuf_1000 +``` + +### 2. Use the baseline (https://github.com/facebookresearch/sam2) to generate rles + +Make sure you've installed https://github.com/facebookresearch/sam2 + +Start server +``` +python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --baseline +``` + +Generate and save rles (one JSON object per line via `-w "\n"`) +``` +$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_baseline_shuf_1000 + +real 13m6.374s +user 0m3.349s +sys 0m4.137s +``` + +### 3. Start server with torchao variant of SAM2 +Start server +``` +python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> +``` + +Generate and save rles (one JSON object per line via `-w "\n"`) +``` +$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_shuf_1000 + +real 12m18.916s +user 0m3.506s +sys 0m4.350s +``` + +### 4. Start server with torchao variant of SAM2 and `--fast` optimizations +Start server +``` +python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast +``` + +Generate and save rles (one JSON object per line via `-w "\n"`) +``` +$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_fast_shuf_1000 + +real 9m23.912s +user 0m3.271s +sys 0m4.138s +``` + +### 5. 
Start server with torchao variant of SAM2 and `--fast` and `--furious` optimizations +Start server +``` +python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast --furious +``` + +Generate and save rles (one JSON object per line via `-w "\n"`) +``` +$ time xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_fast_furious_shuf_1000 + +real 3m24.383s +user 0m3.583s +sys 0m4.519s +``` diff --git a/examples/sam2_amg_server/compare_rle_lists.py b/examples/sam2_amg_server/compare_rle_lists.py new file mode 100644 index 000000000..51a824f1f --- /dev/null +++ b/examples/sam2_amg_server/compare_rle_lists.py @@ -0,0 +1,41 @@ +import fire +import torch +import json +from sam2.utils.amg import rle_to_mask + +""" +Script to calculate mIoU given two lists of rles from upload_rle endpoint +of server. +""" + + +def iou(mask1, mask2): + assert mask1.dim() == 2 + assert mask2.dim() == 2 + intersection = torch.logical_and(mask1, mask2) + union = torch.logical_or(mask1, mask2) + return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2))) + + +def main(path0, path1): + fail_count = 0 + miou_sum = 0.0 + miou_count = 0 + with open(path0, 'r') as f0, open(path1, 'r') as f1: + for line0, line1 in zip(f0, f1): + masks0 = json.loads(line0) + masks1 = json.loads(line1) + if masks0.keys() != masks1.keys(): + fail_count += 1 + continue + for mask0, mask1 in zip(masks0.values(), masks1.values()): + mask0 = torch.from_numpy(rle_to_mask(mask0)) + mask1 = torch.from_numpy(rle_to_mask(mask1)) + miou_sum += iou(mask0, mask1).item() + miou_count += 1 + + print(f"fail_count: {fail_count} mIoU: {miou_sum / miou_count}") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/sam2_amg_server/dog_rle.json b/examples/sam2_amg_server/dog_rle.json new file mode 100644 index 000000000..b6932dfc1 --- /dev/null +++ b/examples/sam2_amg_server/dog_rle.json @@ -0,0 +1 @@ 
+{"mask_0":{"size":[534,800],"counts":[441,93,441,93,440,94,439,95,439,95,438,96,437,97,436,98,435,99,434,100,433,101,432,102,432,102,431,103,430,104,429,105,428,106,427,107,426,108,425,109,424,110,424,110,423,111,422,112,421,113,421,113,420,114,419,115,418,116,417,117,416,118,416,118,415,119,414,120,413,121,413,121,412,122,412,122,411,123,410,124,410,124,409,125,408,126,407,127,407,127,406,128,405,129,404,130,403,131,402,132,401,133,400,134,399,135,398,136,397,137,396,138,395,139,394,140,393,141,391,143,390,144,389,145,388,146,387,147,386,148,385,149,383,151,382,152,381,153,380,154,379,155,378,156,377,157,375,159,374,160,373,161,372,162,371,163,370,164,368,166,367,167,367,167,366,168,364,170,363,171,362,172,361,173,360,174,359,175,358,176,358,176,357,177,355,179,355,179,354,180,353,181,352,182,351,183,351,183,350,184,350,184,349,185,348,186,348,186,347,187,346,188,345,189,344,190,343,191,342,192,341,193,339,195,338,196,338,196,337,197,335,199,334,200,334,200,333,201,332,202,331,203,330,204,329,205,328,206,327,207,326,208,325,209,325,209,324,210,322,212,321,213,321,213,320,214,319,215,318,216,317,217,317,217,316,218,316,218,315,219,314,220,313,221,313,221,312,222,311,223,311,223,310,224,309,225,309,225,309,225,309,225,309,225,308,226,308,226,308,226,308,226,307,227,307,227,307,227,306,228,306,228,305,229,304,230,304,230,304,230,303,231,302,232,302,232,301,233,300,234,300,234,299,235,298,236,298,236,297,237,297,237,296,238,296,238,295,239,295,239,294,240,293,241,292,242,291,243,290,244,289,245,288,246,288,246,287,247,287,247,287,247,286,248,285,249,254,5,26,249,251,8,1,3,20,251,250,14,18,252,249,15,18,252,247,17,17,253,245,20,16,253,244,21,17,252,243,23,19,249,242,24,21,247,240,27,19,248,239,28,18,249,237,30,18,249,235,32,16,251,234,33,15,252,232,36,13,253,230,38,13,253,229,39,12,254,228,40,11,255,227,41,10,256,226,42,9,257,225,44,7,258,224,47,4,259,224,310,223,311,223,311,224,310,224,310,224,310,225,309,226,308,226,308,226,308,227,307,228,306,228,306,229,305,229,305
,230,304,231,303,232,302,232,302,232,302,233,301,233,301,233,301,234,300,234,300,234,300,235,299,235,299,235,299,235,299,235,299,235,299,235,299,236,298,236,298,236,298,236,298,237,297,237,297,237,297,237,297,237,297,237,297,237,297,237,297,237,297,237,295,238,294,240,286,248,282,252,278,255,276,258,273,261,270,264,268,266,264,270,260,273,258,276,255,279,252,282,249,285,246,287,243,291,238,296,229,305,221,313,213,320,204,330,196,338,183,2,6,343,181,354,179,355,179,355,179,355,178,356,178,356,177,357,177,356,178,356,177,357,177,357,176,358,176,358,176,358,176,358,175,359,175,359,175,359,174,360,174,360,173,361,173,361,173,361,172,362,172,362,172,362,171,363,171,363,170,364,170,364,170,364,169,365,169,365,169,365,168,366,168,366,168,366,168,365,168,366,167,367,167,367,167,367,166,369,164,370,164,370,164,370,163,371,162,373,161,373,160,374,159,376,157,377,156,379,155,380,153,382,151,384,149,386,147,388,145,390,143,393,140,396,137,398,135,400,133,403,130,405,128,407,126,409,124,411,122,412,122,413,119,416,117,418,115,420,114,420,113,421,113,421,112,422,112,422,112,421,113,422,112,422,112,422,112,422,112,422,112,422,112,423,111,423,111,423,111,423,111,424,110,424,110,424,110,424,110,424,110,424,109,425,109,425,109,426,107,427,107,428,106,429,104,431,103,432,102,434,99,436,96,440,94,441,92,445,87,449,84,451,82,453,79,456,77,457,76,461,72,464,68,467,66,470,62,473,58,477,55,481,48,490,43,494,39,496,38,500,34,502,31,505,29,511,23,513,21,515,18,517,16,519,15,519,15,520,13,522,11,524,10,525,8,527,6,529,5,200989]},"mask_1":{"size":[534,800],"counts":[262387,9,521,16,516,19,513,24,509,26,507,28,505,30,502,33,500,36,497,38,495,40,493,42,491,44,489,47,486,49,484,53,480,57,477,58,475,60,474,62,471,67,466,70,463,72,462,75,458,77,457,79,454,82,451,84,450,85,448,52,1,34,447,52,1,34,446,52,2,35,445,52,2,36,443,52,3,37,441,53,4,37,440,53,4,38,438,54,4,38,438,53,5,39,437,53,6,38,436,53,7,39,435,53,7,39,434,53,8,40,433,53,8,40,433,53,8,40,432,54,8,41,430,54,9,41,430,54,9,41,430,54,9,42,42
8,55,10,41,428,54,11,41,428,54,11,41,427,54,12,42,426,54,12,42,426,53,12,43,425,53,13,43,425,52,14,43,425,52,14,43,424,53,14,44,423,52,15,44,423,52,15,44,422,52,16,44,422,52,16,44,422,51,17,44,422,51,17,44,421,52,17,44,421,52,17,44,421,52,17,44,420,53,17,44,420,53,17,44,420,53,17,44,419,54,17,44,419,54,17,44,419,54,17,44,418,55,16,45,418,55,16,45,418,55,16,45,418,55,16,45,417,56,16,45,417,56,16,45,417,56,15,46,417,56,15,46,416,57,15,46,416,57,15,46,416,57,15,46,416,57,15,46,415,57,16,46,415,57,16,46,415,57,15,46,416,57,15,46,416,57,15,46,415,58,15,46,415,59,14,46,415,59,13,47,415,60,12,47,415,60,12,47,415,61,11,46,416,61,10,47,415,63,9,47,415,64,7,48,415,65,6,48,415,66,5,47,416,66,4,48,416,66,3,49,416,66,3,49,415,67,3,49,415,67,3,48,416,118,416,118,416,117,417,117,417,117,417,117,417,116,418,116,417,117,417,116,418,116,418,116,418,116,418,116,418,115,419,115,419,114,420,113,421,113,421,113,421,112,422,112,422,112,422,111,423,111,423,110,425,109,425,109,425,108,426,107,427,107,437,96,439,94,442,92,443,90,444,90,445,88,446,88,447,86,448,85,449,84,450,83,451,82,451,82,452,81,453,80,453,80,453,80,454,79,454,79,455,77,456,77,457,76,459,73,461,72,463,69,465,67,468,61,474,28,5,23,479,22,514,17,518,12,75657]},"mask_2":{"size":[534,800],"counts":[187954,1,2,2,527,4,1,2,527,4,1,2,527,4,1,2,527,4,1,2,529,1,2,2,532,2,532,2,531,3,531,3,68346,7,2,3,502,2,1,8,1,22,495,39,491,43,490,44,490,45,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,489,45,489,45,489,45,489,45,489,45,489,45,489,45,489,45,489,45,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,491,43,491,43,491,43,491,43,491,43,491,43,491,43,491,43,491,43,491,43,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,493,41,493,41,493,41,493,41,493,41,493,41,493,41,493,41,493,41,493,41,494,40,494,40,494,40,494,40,494,40,494,40,494,40,494,40,494,40,494,40,494,40,495,39,495,39,495,39,495,39,495,39,495,3
9,495,39,495,39,495,39,495,39,495,38,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,498,36,498,36,498,36,498,36,498,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,498,36,498,36,498,36,498,36,498,36,498,36,498,36,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,500,34,500,34,500,34,500,34,500,34,500,34,500,34,500,34,500,34,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,504,30,504,30,504,30,504,30,504,30,504,30,504,30,505,29,505,29,504,30,504,30,505,29,505,29,505,29,505,29,505,29,505,29,505,29,505,29,505,29,505,29,505,29,506,28,506,28,506,28,506,28,506,28,506,28,506,28,506,28,506,28,506,28,506,28,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,508,26,508,26,508,26,508,26,508,26,508,26,508,26,508,26,508,26,508,26,508,26,508,26,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,510,24,510,24,510,24,510,24,510,24,510,24,510,24,511,23,511,23,511,23,511,23,511,23,511,23,511,23,511,23,511,23,511,23,511,23,511,23,511,23,512,22,512,22,512,22,512,22,511,23,512,22,512,22,512,22,512,22,512,22,513,21,513,21,513,21,513,21,513,21,513,21,513,21,513,21,513,21,513,21,513,21,513,21]},"mask_3":{"size":[534,800],"counts":[0,58,476,58,476,57,477,57,477,56,478,55,479,55,479,54,480,54,480,54,480,53,481,52,482,51,483,51,483,50,484,50,484,50,484,49,485,49,485,48,486,48,486,48,486,48,486,47,487,47,487,47,487,46,488,46,488,45,489,45,489,44,490,44,490,43,491,42,492,42,492,41,493,36,498,27,507,22,512,18,516,16,518,14,520,12,522,10,524,9,525,7,527,6,528,5,529,4,530,3,531,2,400498]},"mask_4":{"size":[534,800],"counts":[261110,14,501,1,3,31,494,
40,491,43,489,45,489,45,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,488,46,489,45,489,45,489,45,489,45,489,45,489,45,489,45,489,45,489,45,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,490,44,491,43,491,43,491,43,491,43,491,43,491,43,491,43,491,43,491,43,491,43,491,43,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,492,42,493,41,493,41,493,41,493,41,493,41,493,41,493,41,493,41,493,41,493,41,494,40,494,40,494,40,494,40,494,40,494,40,494,40,494,40,494,40,494,40,494,40,495,39,495,39,495,39,495,39,495,39,495,39,495,39,495,39,495,39,495,39,495,39,495,39,495,39,495,39,496,38,496,38,496,38,496,38,496,38,496,38,496,38,496,38,497,36,497,38,496,38,496,38,496,38,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,497,37,498,36,498,36,498,36,498,36,498,36,498,36,498,36,499,35,499,35,499,35,499,35,499,35,498,36,498,36,498,36,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,499,35,500,34,500,34,500,34,500,34,500,34,500,34,500,34,500,34,501,33,501,33,501,33,501,33,500,34,501,33,501,33,501,33,501,33,501,33,501,33,501,33,501,33,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,502,32,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,503,31,504,30,504,30,504,30,504,30,504,30,504,30,504,30,504,30,504,30,504,30,504,30,505,29,505,29,505,29,505,29,505,28,506,28,506,28,506,28,506,28,506,28,506,29,506,28,506,28,506,28,506,28,506,28,506,28,506,28,506,28,506,28,506,28,506,28,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,507,27,508,26,508,26,508,26,508,26,508,26,508,26,508,26,508,26,508,26,508,25,509,26,508,26,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,25,509,24,510,24,510,24,510,25,509,25,509,25,510,24,510,24,510,24,510,23,511,24,510,24,510,24,511,22,512,22,512,23,511,23,511,23,511,23,511,23,511,23,511,22,512,22,512,22,51
2,22,512,22,512,22,513,21,513,21,513,21,512,23,512,22,512,21,513,21,513,21,514,20,514,20,514,21,513,21,513,21,513,21,513,21,513,21,513,21,513,21,513,20,514,20,514,20,1]},"mask_5":{"size":[534,800],"counts":[141501,9,520,14,516,18,513,21,510,24,507,27,505,29,501,33,497,37,494,40,490,44,487,47,484,50,482,52,478,56,475,59,471,63,468,66,466,68,466,68,466,68,466,68,466,68,467,67,467,67,467,67,467,67,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,469,65,469,65,469,65,469,65,469,65,469,65,469,65,470,64,470,64,470,64,470,64,470,64,470,64,470,64,470,64,470,64,470,64,471,63,471,63,471,63,471,63,471,63,471,63,471,63,471,63,471,63,399,1,71,63,398,2,71,63,398,2,72,62,398,2,72,62,472,62,472,62,472,62,473,61,472,62,472,62,472,62,473,61,473,61,389,1,83,61,473,61,473,61,386,2,85,61,473,61,473,61,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,477,57,477,57,477,57,477,57,477,57,477,57,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,479,55,479,55,479,55,479,55,479,55,479,55,479,55,479,55,479,55,480,54,480,54,480,54,480,54,480,54,480,54,480,54,480,54,481,53,481,53,481,53,481,53,481,53,481,53,481,53,481,53,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,483,51,483,51,483,51,483,51,483,51,483,51,483,51,483,51,483,51,483,51,484,50,484,50,483,51,483,51,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,485,49,485,49,485,49,485,49,485,49,485,49,485,49,485,49,486,48,486,48,486,48,486,48,486,48,486,48,486,48,486,48,486,48,486,48,486,46,489,37,2,1,495,23,512,6,6,1,167174]},"mask_6":{"size":[534,800],"counts":[202920,2,532,3,531,5,529,8,526,10,524,12,522,15,519,20,514,25,509,29,505,37,497,41,493,42,492,44,490,49,485,51,483,52,482,53,481,54,480,55,479,55,4
79,55,479,55,479,55,479,55,479,55,479,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,58,476,57,477,57,477,57,477,57,477,57,477,58,476,58,476,58,476,58,476,58,476,58,476,58,476,4,8,46,490,44,491,43,491,43,491,43,491,43,490,44,487,47,485,49,481,1,2,50,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,61,473,61,473,61,473,61,473,60,474,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,61,473,62,472,59,475,51,483,50,484,41,3,2,488,34,500,5,1,3,1,3,1,2,103046]},"mask_7":{"size":[534,800],"counts":[141500,10,519,15,516,18,513,21,510,24,507,27,505,29,501,33,497,37,494,40,490,44,487,47,484,50,482,52,478,56,475,59,472,62,468,66,467,67,467,67,467,67,467,67,467,67,467,67,467,67,467,67,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,468,66,469,65,469,65,469,65,469,65,469,65,469,65,469,65,469,65,469,65,469,65,470,64,470,64,470,64,470,64,470,64,470,64,470,64,470,64,470,64,470,64,470,64,471,63,471,63,471,63,471,63,471,63,471,63,471,63,471,63,471,63,471,
63,472,62,472,62,472,62,472,62,472,62,472,62,472,62,472,62,472,62,472,62,473,61,473,61,473,61,473,61,473,61,473,61,473,61,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,474,60,475,59,475,59,474,60,475,59,475,59,475,59,475,59,475,59,475,59,475,59,475,59,476,58,476,58,476,58,476,58,476,58,476,58,476,58,476,58,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,477,57,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,478,56,479,55,479,55,479,55,479,55,479,55,479,55,479,55,479,55,480,54,480,54,480,54,480,54,480,54,480,54,480,54,480,54,480,54,481,53,481,53,481,53,481,53,481,53,481,53,481,53,481,53,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,482,52,483,51,483,51,483,51,483,51,483,51,483,51,483,51,483,51,483,51,483,51,484,50,484,50,484,50,483,51,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,484,50,485,49,485,49,485,49,485,49,485,49,485,49,485,49,485,49,485,49,486,48,486,48,486,48,486,48,486,48,486,48,486,48,486,48,486,48,487,43,491,32,503,15,167707]},"mask_8":{"size":[534,800],"counts":[37954,9,519,22,507,32,499,39,491,49,483,54,478,60,473,63,470,66,465,71,463,73,460,75,458,77,457,78,455,81,452,82,452,83,450,85,448,87,447,87,446,89,445,89,444,91,443,91,442,92,442,93,441,93,441,93,440,94,439,95,439,96,438,96,438,96,438,96,438,96,438,96,438,96,438,96,438,96,438,96,439,95,439,94,441,93,441,92,442,92,442,91,443,89,447,86,452,80,455,77,459,72,464,68,467,64,471,58,477,45,491,26,510,8,359352]},"mask_9":{"size":[534,800],"counts":[168,33,501,33,501,32,502,31,503,30,504,30,504,29,505,28,506,28,506,27,508,25,509,25,509,24,510,23,511,23,511,22,512,21,513,20,514,20,514,19,515,18,516,18,516,17,517,16,518,16,518,15,519,14,520,14,520,13,521,12,522,11,523,11,523,10,524,9,525,9,525,8,526,7,527,7,527,6,528,5,529,4,531,2,405134]},"mask_10":{"size":[534,800],"counts":[272698,3,527,9,523,13,520,16,517,18,515,21,511,24,508,27,506,28,505,29,504,31,502,32,502,32,501,33,500,
34,499,35,499,35,498,36,498,36,498,36,497,37,497,37,497,37,497,37,496,38,496,37,497,37,496,38,496,38,496,37,497,37,496,38,496,38,496,38,496,37,497,37,496,38,496,38,495,38,496,38,495,39,495,39,495,38,495,39,495,38,496,38,496,37,496,38,496,37,496,38,496,37,496,38,496,37,496,38,496,37,496,38,495,38,495,39,495,38,495,39,495,38,495,39,494,39,494,40,493,40,493,41,492,41,492,42,491,42,489,45,484,50,477,56,470,64,465,68,461,73,454,79,449,85,445,89,444,90,444,90,444,89,398,1,46,89,65,5,328,1,45,90,59,13,326,1,44,90,59,17,322,2,44,90,55,22,321,3,43,90,54,23,321,2,44,89,53,26,320,2,43,90,51,28,242,1,77,2,43,90,48,32,233,10,121,90,45,35,214,29,120,91,43,37,202,41,120,91,41,39,192,51,120,90,40,41,180,63,120,90,39,42,168,75,120,90,38,44,160,82,119,91,36,46,155,87,118,92,35,48,150,91,118,91,33,52,106,5,5,124,117,92,32,53,102,138,117,92,31,55,101,138,117,91,30,57,101,138,117,91,27,61,100,138,116,91,27,62,100,138,116,91,25,64,100,138,116,91,24,65,100,138,116,90,23,67,100,138,115,90,20,71,100,138,115,88,18,75,100,138,114,87,16,79,100,138,113,63,4,15,20,81,100,138,113,60,40,83,100,138,113,56,42,85,100,138,112,54,42,88,100,138,112,50,45,88,101,138,111,48,46,90,101,138,110,47,45,93,101,138,110,42,48,94,102,139,108,38,52,95,102,139,108,33,55,97,102,140,3,6,97,30,57,99,102,151,95,27,56,102,103,152,94,22,59,104,103,153,92,18,62,106,103,153,92,14,63,109,103,154,91,10,65,110,104,154,91,7,66,112,104,154,92,4,67,113,104,154,161,114,105,154,159,116,105,153,159,116,106,153,156,119,106,153,153,122,106,153,150,124,107,153,149,125,107,152,148,126,108,151,148,126,109,150,148,127,109,150,147,127,110,149,147,127,111,149,144,129,112,149,142,130,113,149,140,131,114,149,138,131,116,150,134,131,119,150,130,133,121,153,123,132,126,154,26,14,77,128,135,155,24,22,66,125,142,157,19,23,63,124,148,160,12,26,56,129,151,165,1,30,50,134,154,194,44,138,158,191,40,138,165,190,37,134,173,187,36,133,178,185,33,136,180,182,32,138,182,180,30,140,184,178,27,143,186,176,27,144,187,175,25,145,189,173,24,146,191,170,24,146,
194,168,23,147,196,167,21,148,198,165,22,147,200,164,20,148,202,164,15,151,204,165,10,153,206,326,208,324,210,323,211,322,212,319,215,317,217,315,219,313,221,312,222,311,223,309,225,307,227,306,228,304,230,303,231,302,232,299,235,298,236,296,238,295,239,293,241,292,242,290,244,288,246,287,247,285,249,283,251,282,252,280,254,278,256,276,258,274,260,273,261,272,262,270,264,269,265,267,267,265,269,263,271,262,272,260,274,257,277,255,279,252,282,40,1,209,284,248,286,246,288,243,291,241,293,238,296,235,299,233,301,230,304,227,307,224,310,221,313,217,317,214,320,211,323,208,326,203,331,199,335,196,338,192,342,188,346,184,350,178,356,174,360,170,364,163,371,159,375,154,380,149,385,145,389,136,398,131,403,127,407,124,410,121,413,117,417,115,419,112,422,109,425,107,427,104,430,101,433,99,435,96,438,94,440,92,442,90,444,86,448,85,449,83,451,80,454,78,456,76,458,74,460,72,462,70,464,68,466,66,468,64,470,62,472,59,475,57,477,55,479,54,480,52,482,50,484,48,486,46,488,44,490,42,492,41,493,40,494,37,497,35,499,34,500,31,503,31,503]},"mask_11":{"size":[534,800],"counts":[58,58,476,58,476,58,476,57,476,57,477,56,477,56,477,57,477,57,476,57,477,57,476,57,476,57,477,57,476,57,477,57,476,57,477,57,477,56,477,57,477,57,477,57,476,57,477,57,477,57,477,57,476,57,477,57,476,58,476,58,475,59,475,59,474,60,474,60,474,60,475,59,479,55,487,47,495,38,502,32,507,27,511,23,515,19,519,15,525,9,530,3,403068]},"mask_12":{"size":[534,800],"counts":[276,40,494,40,494,39,495,38,496,37,498,35,499,34,500,32,502,32,502,31,503,30,504,29,506,27,507,26,508,26,508,24,510,24,510,23,511,22,512,21,513,19,515,19,515,18,516,17,517,16,518,15,519,14,520,13,521,12,522,12,522,11,523,10,524,9,525,8,526,7,528,5,529,4,1063,1,406629]},"mask_13":{"size":[534,800],"counts":[336031,5,526,9,521,15,517,18,514,20,514,20,513,22,509,25,506,28,505,29,503,30,502,32,500,34,497,37,494,39,492,41,490,43,489,45,486,47,484,49,482,52,481,54,480,54,480,54,481,54,480,55,480,55,479,56,30,2,446,58,23,21,432,60,18,24,433,61,13,25,435,98,436,96
,438,93,441,90,444,87,447,85,450,82,452,80,454,79,455,77,458,74,460,72,462,70,464,68,467,64,471,61,473,59,476,57,478,54,480,53,482,50,485,47,487,45,490,37,497,31,504,25,509,18,517,11,524,4,59693]},"mask_14":{"size":[534,800],"counts":[58,58,476,58,475,59,475,58,475,58,475,58,476,57,476,58,476,57,476,57,477,57,476,57,477,57,476,57,477,57,476,58,476,57,477,57,476,57,477,57,477,57,476,57,477,57,477,57,476,58,476,58,476,57,476,58,476,58,475,59,475,59,474,60,473,61,473,61,473,61,474,60,479,55,486,47,494,40,501,33,506,28,511,23,515,19,519,15,525,8,530,3,19785,1,533,1,1065,1,381683]},"mask_15":{"size":[534,800],"counts":[272695,8,524,11,521,15,518,18,514,21,512,24,509,25,508,27,506,29,504,30,502,32,501,33,501,34,499,35,498,36,498,36,497,37,497,37,496,38,496,38,496,37,496,38,496,38,496,38,495,39,495,38,496,38,496,38,495,39,495,39,495,38,496,38,496,38,495,39,495,38,496,38,495,39,495,38,496,38,495,39,495,39,495,38,495,39,495,39,495,38,495,39,495,38,495,39,495,38,496,38,495,38,495,39,495,38,495,39,495,38,495,39,495,38,495,39,494,39,494,40,493,40,494,40,493,40,493,40,493,41,492,42,491,42,491,43,490,43,489,45,488,46,488,45,489,45,489,44,490,44,491,43,491,42,493,41,493,41,493,40,494,40,495,39,495,39,495,38,496,38,496,38,496,38,496,37,498,36,498,36,498,36,498,36,498,36,498,36,499,34,500,34,500,34,500,34,500,34,500,33,501,33,501,32,503,31,503,31,503,30,504,29,505,28,506,27,508,23,516,13,96320]},"mask_16":{"size":[534,800],"counts":[278963,1,531,3,531,4,530,4,530,4,530,5,528,6,528,6,527,7,527,7,527,7,526,8,526,9,525,9,524,10,524,10,524,10,523,11,523,11,523,11,522,12,521,13,521,13,520,14,519,15,519,15,518,16,518,16,518,16,518,16,517,17,517,16,518,16,517,17,517,17,517,17,517,17,517,17,517,17,517,17,517,17,517,17,517,16,518,16,518,16,518,16,518,16,518,16,518,16,519,15,519,15,519,15,518,15,519,15,518,16,518,16,518,16,518,15,519,15,519,15,520,14,520,14,521,12,523,11,524,10,524,9,525,9,526,8,526,7,527,7,528,6,528,5,530,4,530,3,532,2,108726]},"mask_17":{"size":[534,800],"counts":[34832,2,5
32,3,531,4,530,4,530,5,529,6,528,6,527,8,526,9,525,10,524,10,524,11,523,11,523,11,523,11,523,11,523,11,523,11,523,11,523,10,524,10,524,10,524,10,524,10,525,8,526,8,526,8,527,6,528,6,530,3,376876]},"mask_18":{"size":[534,800],"counts":[316541,9,521,16,515,21,511,23,508,27,503,31,502,33,498,36,495,39,493,41,493,41,492,43,489,45,488,47,485,49,483,52,481,54,479,55,477,57,474,61,471,63,469,65,467,67,465,69,462,72,459,75,455,79,453,81,451,83,449,85,447,87,445,88,443,91,441,93,439,94,439,95,438,96,437,97,437,96,439,95,439,95,440,94,441,92,443,91,443,91,444,89,446,88,446,87,448,86,448,85,450,84,450,84,450,83,452,82,452,81,454,79,455,78,457,77,457,75,459,74,460,72,463,69,465,66,468,63,471,53,481,46,488,42,492,37,498,34,500,30,504,23,512,10,524,6,528,4,71737]},"mask_19":{"size":[534,800],"counts":[310552,8,517,19,509,25,503,32,498,36,489,46,486,48,486,48,485,49,485,50,484,50,483,51,483,52,482,52,481,53,481,53,481,53,481,54,353,6,120,55,346,13,120,55,319,41,118,56,311,49,118,56,305,54,119,56,296,6,2,55,118,57,287,72,118,58,278,80,118,58,271,87,117,59,222,6,9,8,9,9,2,93,117,59,220,32,1,105,117,59,220,138,116,60,220,138,116,60,220,138,116,60,220,138,116,60,220,138,116,60,220,138,115,61,220,138,115,61,220,138,114,62,220,138,114,61,221,138,113,61,222,138,113,58,225,138,112,57,227,138,112,54,230,138,111,52,233,138,111,50,235,138,110,47,239,138,110,43,243,138,109,42,245,138,109,34,253,138,108,31,257,137,109,29,259,136,110,22,60,6,200,135,110,19,62,9,199,134,111,15,62,14,198,132,113,10,66,16,197,130,115,8,67,18,196,128,117,6,67,20,196,127,119,3,68,22,195,125,189,25,195,123,188,29,194,121,188,31,194,118,191,32,193,116,192,33,193,113,189,40,192,110,191,42,191,107,193,43,191,103,197,44,190,99,133,1,67,44,190,86,146,1,7,1,59,45,189,84,212,49,189,82,210,53,189,79,211,56,188,78,209,59,188,77,209,60,188,76,208,62,188,76,204,66,188,76,199,72,187,75,193,79,187,73,187,87,187,71,184,92,187,70,184,93,187,68,174,105,187,63,169,115,187,62,164,121,187,60,162,126,186,58,160,130,186,56,160,132,186,48
,164,136,186,40,165,141,188,37,165,143,189,33,166,143,192,30,167,144,193,27,169,143,195,22,172,143,197,19,170,146,199,16,170,146,202,12,172,147,203,9,170,151,204,6,159,1,10,152,206,3,158,7,6,152,366,166,365,168,365,167,365,166,362,171,358,174,349,183,346,187,343,189,342,191,341,191,340,192,341,192,341,192,342,190,344,188,345,188,346,186,347,186,347,185,345,187,346,186,346,187,346,187,340,1,2,189,336,196,335,197,336,196,335,197,336,197,336,196,336,197,336,196,337,195,336,197,334,199,325,207,324,208,321,211,323,208,323,209,323,209,322,210,324,208,322,210,323,208,324,208,321,211,318,213,318,213,316,216,318,212,321,210,320,212,319,211,323,208,325,207,324,206,326,204,329,200,332,199,335,195,339,192,342,186,348,182,352,177,357,173,361,168,366,164,370,158,376,151,383,147,387,142,392,135,399,131,403,126,408,123,411,118,416,117,417,114,420,111,423,109,425,105,429,101,433,100,434,97,437,93,441,92,442,91,443,86,448,84,450,84,450,82,452,80,454,77,457,72,462,67,467,67,467,67,467,66,468,65,469,60,474,59,475,58,476,56,478,55,479,52,482,50,484,42,492,42,492,42,492,42,492,36,1,4,493,36,1,3,494,36,498,34,500,34,500,32,502,30,504,30,504]}} \ No newline at end of file diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index c00cc990e..91dd56f1e 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -30,6 +30,15 @@ # torch.set_float32_matmul_precision('high') + +def iou(mask1, mask2): + assert mask1.dim() == 2 + assert mask2.dim() == 2 + intersection = torch.logical_and(mask1, mask2) + union = torch.logical_or(mask1, mask2) + return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2))) + + def show_anns(anns): if len(anns) == 0: return @@ -49,17 +58,44 @@ def show_anns(anns): return torch.stack(ms) -def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=False, points_per_batch=64): +def profiler_runner(path, fn, *args, **kwargs): + with torch.profiler.profile( + 
activities=[torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA], + record_shapes=True) as prof: + result = fn(*args, **kwargs) + prof.export_chrome_trace(path) + return result + + +def main(checkpoint_path, + baseline=False, + fast=False, + furious=False, + unittest=False, + benchmark=False, + profile=None, + verbose=False, + points_per_batch=64, + port=5000, + host="127.0.0.1", + dry=False): if verbose: - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') logging.info(f"Running with fast set to {fast} and furious set to {furious}") + logging.info(f"Running with port {port} and host {host}") - if fast: - from torchao._models.sam2.build_sam import build_sam2 - from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator - else: + if baseline: + logging.info(f"Importing sam2 from outside of torchao. If this errors, install https://github.com/facebookresearch/sam2") from sam2.build_sam import build_sam2 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + from sam2.utils.amg import rle_to_mask + else: + from torchao._models.sam2.build_sam import build_sam2 + from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + from torchao._models.sam2.utils.amg import rle_to_mask device = "cuda" from pathlib import Path @@ -70,7 +106,7 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) logging.info(f"Using {points_per_batch} points_per_batch") - mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch) + mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") if furious: torch.set_float32_matmul_precision('high') @@ -107,6 +143,37 @@ def main(checkpoint_path, 
fast=False, furious=False, benchmark=False, verbose=Fa logging.info(f"Running one iteration to compile.") masks = mask_generator.generate(example_image) logging.info(f"First iteration took {time.time() - t}s.") + if unittest: + logging.info(f"Running strict comparison to reference mask") + import json + ref_masks = json.loads(open("dog_rle.json").read()) + ret_data = {} + for mask_id in range(len(masks)): + ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"] + v0_areas = [] + v1_areas = [] + miou_sum = 0.0 + miou_count = 0 + for k0 in ref_masks: + assert k0 in ret_data, f"Expected {k0} to be in return data" + from torchao._models.sam2.utils.amg import area_from_rle + v0_area = area_from_rle(ref_masks[k0]) + v1_area = area_from_rle(ret_data[k0]) + v0_areas.append(v0_area) + v1_areas.append(v1_area) + if v0_area != v1_area: + print(f"v0 area {v0_area} doesn't match v1 area {v1_area}") + v0_mask = torch.from_numpy(rle_to_mask(ref_masks[k0])) + v1_mask = torch.from_numpy(rle_to_mask(ret_data[k0])) + if not torch.allclose(v0_mask, v1_mask): + miou_sum += iou(v0_mask, v1_mask) + miou_count += 1 + print(f"Masks don't match for key {k0}. 
IoU is {iou(v0_mask, v1_mask)}") + if miou_count == 0: + print("Masks exactly match reference.") + else: + print(f"mIoU is {miou_sum / miou_count}") + if benchmark: logging.info(f"Running 3 warumup iterations.") for _ in range(3): @@ -121,7 +188,13 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory)) max_memory_allocated_bytes = max_memory_allocated_bytes >> 20 print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%") - return + + if profile is not None: + print(f"Saving profile under {profile}") + profiler_runner(profile, mask_generator.generate, example_image) + + if dry: + return app = FastAPI() @@ -133,6 +206,25 @@ def main(checkpoint_path, fast=False, furious=False, benchmark=False, verbose=Fa allow_methods=["*"], allow_headers=["*"], ) + + @app.post("/upload_rle") + async def upload_rle(image: UploadFile = File(...)): + # Save the uploaded image to a temporary location + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{image.filename}") + with open(temp_file.name, "wb") as b: + shutil.copyfileobj(image.file, b) + + # Read the image back into memory to send as response + example_image = cv2.imread(temp_file.name) + example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB) + t = time.time() + with torch.backends.cuda.sdp_kernel(enable_cudnn=True): + masks = mask_generator.generate(example_image) + print(f"Took {time.time() - t} to generate a mask for input image.") + ret_data = {} + for mask_id in range(len(masks)): + ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"] + return ret_data @app.post("/upload") async def upload_image(image: UploadFile = File(...)): @@ -143,6 +235,7 @@ async def upload_image(image: UploadFile = File(...)): # Read the image back into memory to send as response example_image = cv2.imread(temp_file.name) + example_image = 
cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB) t = time.time() with torch.backends.cuda.sdp_kernel(enable_cudnn=True): masks = mask_generator.generate(example_image) @@ -150,6 +243,8 @@ async def upload_image(image: UploadFile = File(...)): # Save an example plt.figure(figsize=(example_image.shape[1]/100., example_image.shape[0]/100.), dpi=100) plt.imshow(example_image) + for i in range(len(masks)): + masks[i]["segmentation"] = rle_to_mask(masks[i]["segmentation"]) show_anns(masks) plt.axis('off') plt.tight_layout() @@ -163,7 +258,7 @@ async def upload_image(image: UploadFile = File(...)): return StreamingResponse(BytesIO(image_data), media_type="image/png") - uvicorn.run(app, host="127.0.0.1", port=5000, log_level="info") + uvicorn.run(app, host=host, port=port, log_level="info") if __name__ == "__main__": fire.Fire(main) diff --git a/torchao/_models/sam2/automatic_mask_generator.py b/torchao/_models/sam2/automatic_mask_generator.py index 3f94b6618..c1a53408f 100644 --- a/torchao/_models/sam2/automatic_mask_generator.py +++ b/torchao/_models/sam2/automatic_mask_generator.py @@ -260,14 +260,16 @@ def _process_crop( x0, y0, x1, y1 = crop_box cropped_im = image[y0:y1, x0:x1, :] cropped_im_size = cropped_im.shape[:2] - self.predictor.set_image(cropped_im) + with torch.autograd.profiler.record_function("set_image"): + self.predictor.set_image(cropped_im) # Get points for this crop points_scale = np.array(cropped_im_size)[None, ::-1] points_for_image = self.point_grids[crop_layer_idx] * points_scale # Generate masks for this crop in batches - data = MaskData() + # data = MaskData() + data = None points_per_batch = self.points_per_batch if self.points_per_batch is None: points_per_batch = len(points_for_image) @@ -275,23 +277,33 @@ def _process_crop( batch_data = self._process_batch( points, cropped_im_size, crop_box, orig_size, normalize=True ) - data.cat(batch_data) - del batch_data + with torch.autograd.profiler.record_function("data.cat"): + if data is None: + data = 
batch_data + else: + data.cat(batch_data) + del batch_data self.predictor.reset_predictor() - # Remove duplicates within this crop. - keep_by_nms = batched_nms( - data["boxes"].float(), - data["iou_preds"], - torch.zeros_like(data["boxes"][:, 0]), # categories - iou_threshold=self.box_nms_thresh, - ) - data.filter(keep_by_nms) + with torch.autograd.profiler.record_function("batched_nms"): + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) - # Return to the original image frame - data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) - data["points"] = uncrop_points(data["points"], crop_box) - data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + with torch.autograd.profiler.record_function("filter"): + data.filter(keep_by_nms) + + with torch.autograd.profiler.record_function("uncrop_boxes_xyxy"): + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + with torch.autograd.profiler.record_function("uncrop_points"): + data["points"] = uncrop_points(data["points"], crop_box) + with torch.autograd.profiler.record_function("crop_boxes"): + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) return data @@ -315,35 +327,40 @@ def _process_batch( in_labels = torch.ones( in_points.shape[0], dtype=torch.int, device=in_points.device ) - masks, iou_preds, low_res_masks = self.predictor._predict( - in_points[:, None, :], - in_labels[:, None], - multimask_output=self.multimask_output, - return_logits=True, - ) + with torch.autograd.profiler.record_function("_predict"): + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) # Serialize predictions and store in MaskData - data = MaskData( - 
masks=masks.flatten(0, 1), - iou_preds=iou_preds.flatten(0, 1), - points=points.repeat_interleave(masks.shape[1], dim=0), - low_res_masks=low_res_masks.flatten(0, 1), - ) + with torch.autograd.profiler.record_function("MaskData"): + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) del masks if not self.use_m2m: - # Filter by predicted IoU - if self.pred_iou_thresh > 0.0: - keep_mask = data["iou_preds"] > self.pred_iou_thresh - data.filter(keep_mask) - - # Calculate and filter by stability score - data["stability_score"] = calculate_stability_score( - data["masks"], self.mask_threshold, self.stability_score_offset - ) - if self.stability_score_thresh > 0.0: - keep_mask = data["stability_score"] >= self.stability_score_thresh - data.filter(keep_mask) + with torch.autograd.profiler.record_function("thresh and filter"): + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + with torch.autograd.profiler.record_function("calculate_stability_score"): + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + with torch.autograd.profiler.record_function("stability_score_thresh"): + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) else: # One step refinement using previous mask predictions in_points = self.predictor._transforms.transform_coords( @@ -369,21 +386,26 @@ def _process_batch( keep_mask = data["stability_score"] >= self.stability_score_thresh data.filter(keep_mask) - # Threshold masks and calculate boxes - data["masks"] = data["masks"] > self.mask_threshold - data["boxes"] = batched_mask_to_box(data["masks"]) + with 
torch.autograd.profiler.record_function("Threshold masks and calculate boxes"): + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) - # Filter boxes that touch crop boundaries - keep_mask = ~is_box_near_crop_edge( - data["boxes"], crop_box, [0, 0, orig_w, orig_h] - ) - # if not torch.all(keep_mask): - data.filter(keep_mask) + with torch.autograd.profiler.record_function("is_box_near_crop_edge"): + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + + with torch.autograd.profiler.record_function("filter(keep_mask)"): + # if not torch.all(keep_mask): + data.filter(keep_mask) - # Compress to RLE - data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) - data["rles"] = mask_to_rle_pytorch_2(data["masks"]) - del data["masks"] + with torch.autograd.profiler.record_function("uncrop_masks"): + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch_2(data["masks"]) + del data["masks"] return data diff --git a/torchao/_models/sam2/modeling/sam/mask_decoder.py b/torchao/_models/sam2/modeling/sam/mask_decoder.py index f14fec57d..730ad97a9 100644 --- a/torchao/_models/sam2/modeling/sam/mask_decoder.py +++ b/torchao/_models/sam2/modeling/sam/mask_decoder.py @@ -209,8 +209,9 @@ def predict_masks( pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) b, c, h, w = src.shape - # Run the transformer - hs, src = self.transformer(src, pos_src, tokens) + with torch.autograd.profiler.record_function("self.transformer"): + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, s, :] mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] diff --git a/torchao/_models/sam2/sam2_image_predictor.py b/torchao/_models/sam2/sam2_image_predictor.py index 
56c3a2a51..630364395 100644 --- a/torchao/_models/sam2/sam2_image_predictor.py +++ b/torchao/_models/sam2/sam2_image_predictor.py @@ -114,7 +114,8 @@ def set_image( len(input_image.shape) == 4 and input_image.shape[1] == 3 ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" logging.info("Computing image embeddings for the provided image...") - backbone_out = self.model.forward_image(input_image) + with torch.autograd.profiler.record_function("forward_image"): + backbone_out = self.model.forward_image(input_image) _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos if self.model.directly_add_no_mem_embed: @@ -403,11 +404,12 @@ def _predict( else: concat_points = (box_coords, box_labels) - sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( - points=concat_points, - boxes=None, - masks=mask_input, - ) + with torch.autograd.profiler.record_function("self.model.sam_prompt_encoder"): + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) # Predict masks batched_mode = ( @@ -417,20 +419,22 @@ def _predict( feat_level[img_idx].unsqueeze(0) for feat_level in self._features["high_res_feats"] ] - low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( - image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), - image_pe=self.model.sam_prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - repeat_image=batched_mode, - high_res_features=high_res_features, - ) + with torch.autograd.profiler.record_function("self.model.sam_mask_decoder"): + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + 
image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) - # Upscale the masks to the original image resolution - masks = self._transforms.postprocess_masks( - low_res_masks, self._orig_hw[img_idx] - ) + with torch.autograd.profiler.record_function("self._transforms.postprocess_masks"): + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) if not return_logits: masks = masks > self.mask_threshold