Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3ba859f

Browse files
author
Mark-ZhouWX
committed Sep 19, 2023
add inference one image
1 parent 3ffde59 commit 3ba859f

File tree

2 files changed

+154
-1
lines changed

2 files changed

+154
-1
lines changed
 
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import argparse
2+
import os
3+
4+
import cv2
5+
import numpy as np
6+
7+
import mindspore as ms
8+
9+
from segment_anything.build_sam import sam_model_registry
10+
from segment_anything.dataset.transform import TransformPipeline, ImageNorm
11+
from segment_anything.utils.transforms import ResizeLongestSide
12+
import matplotlib.pyplot as plt
13+
import time
14+
15+
from use_sam_with_promts import show_mask, show_box
16+
17+
18+
class Timer:
    """Context manager that prints the wall-clock time spent inside its ``with`` block.

    On exit it prints ``"<name> cost time <seconds>"`` with three decimals.

    Args:
        name (str): label printed in front of the elapsed time.

    Attributes:
        name (str): the label given at construction.
        start (float): ``time.time()`` taken on entry (0.0 before the first use).
        end (float): ``time.time()`` taken on exit (0.0 before the first exit).
    """

    def __init__(self, name=''):
        self.name = name
        self.start = 0.0
        self.end = 0.0

    def __enter__(self):
        """Record the start time and return the timer itself."""
        self.start = time.time()
        # Return the instance so `with Timer(...) as t:` binds the timer, not None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the end time and print the elapsed seconds.

        Returns None, so an exception raised inside the block still propagates.
        """
        self.end = time.time()
        print(f'{self.name} cost time {self.end - self.start:.3f}')
30+
31+
32+
class ImageResizeAndPad:
    """Resize an image with ``ResizeLongestSide`` and zero-pad it to a square model input."""

    def __init__(self, target_size):
        """
        Args:
            target_size (int): target size of model input (1024 in sam)
        """
        self.target_size = target_size
        self.transform = ResizeLongestSide(target_size)

    def __call__(self, result_dict):
        """
        Resize the image via ``self.transform`` and then pad it with zeros on the bottom and
        right so that both sides equal the long side (1024*1024 in sam). The boxes are mapped
        through ``self.transform.apply_boxes`` and cast to float32.

        Required keys: image, boxes
        Update keys: image, boxes
        Add keys:
            origin_hw (np.array): int32 array with shape (4), represents original image height, width
                and resized height, width, respectively. This array records the trace of image shape
                transformation and is used for visualization.
            image_pad_area (Tuple): image padding area in h and w direction, in the format of
                ((pad_h_left, pad_h_right), (pad_w_left, pad_w_right))

        Args:
            result_dict (dict): sample dict holding an ``image`` array of shape (h, w, c) and ``boxes``.

        Returns:
            dict: the same ``result_dict`` object, modified in place.
        """
        image = result_dict['image']
        boxes = result_dict['boxes']

        og_h, og_w, _ = image.shape
        image = self.transform.apply_image(image)
        resized_h, resized_w, _ = image.shape

        # Sanity check: the resize step is expected to bring the long side to target_size.
        max_dim = max(resized_h, resized_w)  # long side length
        assert max_dim == self.target_size, (
            f'resized long side {max_dim} does not match target size {self.target_size}'
        )

        # Pad 0 to the bottom and right side only, so the image content stays at the top-left.
        pad_h = max_dim - resized_h
        pad_w = max_dim - resized_w
        img_padding = ((0, pad_h), (0, pad_w), (0, 0))
        image = np.pad(image, pad_width=img_padding, constant_values=0)  # (h, w, c)

        # Adjust bounding boxes using the original image size.
        boxes = self.transform.apply_boxes(boxes, (og_h, og_w)).astype(np.float32)

        # Record image shape trace for visualization.
        result_dict['origin_hw'] = np.array([og_h, og_w, resized_h, resized_w], np.int32)
        result_dict['image'] = image
        result_dict['boxes'] = boxes
        result_dict['image_pad_area'] = img_padding[:2]

        return result_dict
82+
83+
84+
def infer(args):
    """Run box-prompted SAM inference on one image, then save and show the predicted mask.

    The image is read from ``args.image_path``, preprocessed, fed to the network twice
    (each run is timed separately), and the thresholded mask is resized back to the
    original image size and drawn over the image together with the prompt box.
    The figure is saved to ``args.image_path + '_infer.jpg'``.

    Args:
        args (argparse.Namespace): must provide ``mode``, ``device``, ``image_path``,
            ``model_type``, ``checkpoint`` and ``amp_level``.

    Raises:
        FileNotFoundError: if ``args.image_path`` cannot be read as an image.
    """
    ms.context.set_context(mode=args.mode, device_target=args.device)

    # Step1: data preparation
    with Timer('preprocess'):
        transform_list = [
            ImageResizeAndPad(target_size=1024),
            ImageNorm(),
        ]
        transform_pipeline = TransformPipeline(transform_list)

        image_path = args.image_path
        image_np = cv2.imread(image_path)
        # cv2.imread signals failure by returning None instead of raising; fail early
        # with a clear message rather than an opaque error inside cvtColor.
        if image_np is None:
            raise FileNotFoundError(f'Cannot read image from {image_path!r}')
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        boxes_np = np.array([[425, 600, 700, 875]])  # hard-coded box prompt for the demo image

        transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
        image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
        image = ms.Tensor(image).unsqueeze(0)  # add batch axis; presumably (b, 3, 1024, 1024) - TODO confirm
        boxes = ms.Tensor(boxes).unsqueeze(0)  # b, n, 4

    # Step2: inference
    with Timer('model inference'):
        with Timer('load weight and build net'):
            network = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
        ms.amp.auto_mixed_precision(network=network, amp_level=args.amp_level)
        mask_logits = network(image, boxes)[0]  # (1, 1, 1024, 1024)

    # Run once more and time it separately from the first call.
    with Timer('Second time inference'):
        mask_logits = network(image, boxes)[0]  # (1, 1, 1024, 1024)

    # Step3: post-process
    with Timer('post-process'):
        mask_logits = mask_logits.asnumpy()[0, 0] > 0.0
        mask_logits = mask_logits.astype(np.uint8)
        # origin_hw = (original h, original w, resized h, resized w): crop away the padding,
        # then resize back to the original size. cv2 expects dsize as (width, height);
        # plain ints avoid depending on numpy-scalar parsing.
        resized_h, resized_w = int(origin_hw[2]), int(origin_hw[3])
        original_size_wh = (int(origin_hw[1]), int(origin_hw[0]))
        final_mask = cv2.resize(mask_logits[:resized_h, :resized_w], original_size_wh,
                                interpolation=cv2.INTER_CUBIC)

    # Step4: visualize
    plt.imshow(image_np)
    show_box(boxes_np[0], plt.gca())
    show_mask(final_mask, plt.gca())
    plt.savefig(args.image_path + '_infer.jpg')
    plt.show()
128+
129+
130+
if __name__ == '__main__':
    # Command-line entry point: parse the options, echo them, then run inference.
    parser = argparse.ArgumentParser(description="Runs inference on one image")
    parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.")
    parser.add_argument(
        "--model-type",
        type=str,
        default='vit_b',
        help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default='./models/sam_vit_b-35e4849c.ckpt',
        # The help text used to be a copy of the --model-type help; describe the checkpoint instead.
        help="Path to the model checkpoint file to load.",
    )

    parser.add_argument("--device", type=str, default="Ascend", help="The device to run generation on.")
    parser.add_argument("--amp_level", type=str, default="O2", help="auto mixed precision level O0, O2.")
    parser.add_argument("--mode", type=int, default=0, help="MindSpore context mode. 0 for graph, 1 for pynative.")

    args = parser.parse_args()
    print(args)
    infer(args)

‎research/segment-anything/segment_anything/dataset/transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __call__(self, result_dict):
5252
"""
5353
Norm an image with the given mean and std, and also transpose the channels when specified.
5454
55-
Required keys: image
55+
Required keys: image, image_pad_area
5656
Updated keys: image
5757
Added keys:
5858
"""

0 commit comments

Comments
 (0)
Please sign in to comment.