import argparse
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np

import mindspore as ms

from segment_anything.build_sam import sam_model_registry
from segment_anything.dataset.transform import TransformPipeline, ImageNorm
from segment_anything.utils.transforms import ResizeLongestSide

from use_sam_with_promts import show_mask, show_box


class Timer:
    """Context manager that prints the wall-clock time spent inside its block."""

    def __init__(self, name=''):
        self.name = name
        self.start = 0.0
        self.end = 0.0

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        print(f'{self.name} cost time {self.end - self.start:.3f}s')

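# A minimal usage sketch of Timer (names are illustrative):
#
#     with Timer('my step'):
#         run_my_step()
#
# prints "my step cost time 0.123s" when the block exits.

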
class ImageResizeAndPad:

    def __init__(self, target_size):
        """
        Args:
            target_size (int): target size of the model input (1024 for SAM)
        """
        self.target_size = target_size
        self.transform = ResizeLongestSide(target_size)

    def __call__(self, result_dict):
        """
        Resize the input so that its long side matches the target size, then pad it to the
        model input size (1024*1024 for SAM).
        Required keys: image, boxes
        Updated keys: image, boxes
        Added keys:
            origin_hw (np.array): array with shape (4,) holding the original image height and
                width and the resized height and width, respectively. It records the trace of
                the image shape transformation and is used for visualization.
            image_pad_area (Tuple): image padding area along the h and w directions, in the
                format ((pad_h_left, pad_h_right), (pad_w_left, pad_w_right))
        """

        image = result_dict['image']
        boxes = result_dict['boxes']

        og_h, og_w, _ = image.shape
        image = self.transform.apply_image(image)
        resized_h, resized_w, _ = image.shape

        # Pad the image to the square model input
        h, w, _ = image.shape
        max_dim = max(h, w)  # long side length
        assert max_dim == self.target_size
        # pad zeros to the bottom and right sides
        pad_h = max_dim - h
        pad_w = max_dim - w
        img_padding = ((0, pad_h), (0, pad_w), (0, 0))
        image = np.pad(image, pad_width=img_padding, constant_values=0)  # (h, w, c)

        # Map bounding boxes into resized-image coordinates
        boxes = self.transform.apply_boxes(boxes, (og_h, og_w)).astype(np.float32)

        # Record the image shape trace for visualization
        result_dict['origin_hw'] = np.array([og_h, og_w, resized_h, resized_w], np.int32)
        result_dict['image'] = image
        result_dict['boxes'] = boxes
        result_dict['image_pad_area'] = img_padding[:2]

        return result_dict


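# A worked example with hypothetical numbers (assuming ResizeLongestSide rounds the
# scaled size to the nearest integer): a 600x800 input with target_size=1024 is
# resized to 768x1024, then zero-padded at the bottom to 1024x1024; origin_hw
# becomes [600, 800, 768, 1024] and image_pad_area becomes ((0, 256), (0, 0)).

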
def infer(args):
    ms.context.set_context(mode=args.mode, device_target=args.device)

    # Step 1: data preparation
    with Timer('preprocess'):
        transform_list = [
            ImageResizeAndPad(target_size=1024),
            ImageNorm(),
        ]
        transform_pipeline = TransformPipeline(transform_list)

        image_path = args.image_path
        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        boxes_np = np.array([[425, 600, 700, 875]])  # box prompt in (x0, y0, x1, y1) format

        transformed = transform_pipeline(dict(image=image_np, boxes=boxes_np))
        image, boxes, origin_hw = transformed['image'], transformed['boxes'], transformed['origin_hw']
        image = ms.Tensor(image).unsqueeze(0)  # (b, 3, 1024, 1024)
        boxes = ms.Tensor(boxes).unsqueeze(0)  # (b, n, 4)

    # Step 2: inference
    with Timer('model inference'):
        with Timer('load weight and build net'):
            network = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
            ms.amp.auto_mixed_precision(network=network, amp_level=args.amp_level)
        mask_logits = network(image, boxes)[0]  # (1, 1, 1024, 1024)

    with Timer('Second time inference'):
        mask_logits = network(image, boxes)[0]  # (1, 1, 1024, 1024)
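    # In graph mode (mode=0), the first forward pass above also triggers graph
    # compilation, so the second timed run reflects steady-state inference latency.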

    # Step 3: post-process
    with Timer('post-process'):
        mask = (mask_logits.asnumpy()[0, 0] > 0.0).astype(np.uint8)
        # Crop the valid (unpadded) resized region, then scale the mask back to the
        # original resolution; note that cv2.resize expects (width, height).
        final_mask = cv2.resize(mask[:origin_hw[2], :origin_hw[3]],
                                (int(origin_hw[1]), int(origin_hw[0])),
                                interpolation=cv2.INTER_CUBIC)

    # Step 4: visualize
    plt.imshow(image_np)
    show_box(boxes_np[0], plt.gca())
    show_mask(final_mask, plt.gca())
    plt.savefig(args.image_path + '_infer.jpg')
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Runs inference on one image.")
    parser.add_argument("--image_path", type=str, default='./images/truck.jpg', help="Path to an input image.")
    parser.add_argument(
        "--model-type",
        type=str,
        default='vit_b',
        help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b'].",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default='./models/sam_vit_b-35e4849c.ckpt',
        help="Path to the model checkpoint file.",
    )

    parser.add_argument("--device", type=str, default="Ascend", help="The device to run inference on.")
    parser.add_argument("--amp_level", type=str, default="O2", help="Auto mixed precision level, in ['O0', 'O2'].")
    parser.add_argument("--mode", type=int, default=0, help="MindSpore context mode. 0 for graph mode, 1 for pynative mode.")

    args = parser.parse_args()
    print(args)
    infer(args)
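
# Example invocation (the script name and paths are illustrative; adjust to your local files):
#   python sam_inference.py --image_path ./images/truck.jpg --model-type vit_b \
#       --checkpoint ./models/sam_vit_b-35e4849c.ckpt --device Ascend --amp_level O2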