# main.py
# Detect an object with a YOLOv8 model, refine the detection into a mask with
# Segment Anything (SAM), and write the resulting bounding boxes and polygon
# outline to text files.
from ultralytics import YOLO
import numpy as np
import cv2
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
import matplotlib.pyplot as plt
def yolov8_detection(model, image_path):
    """Run a YOLOv8 model on one image and return its first detected box.

    Args:
        model: Callable detection model. It is invoked as
            ``model(image, stream=True)`` and must yield result objects that
            expose ``boxes.xyxy``.
        image_path: Path of the image file to load with OpenCV.

    Returns:
        tuple: ``(bbox, image)`` where ``bbox`` is the first entry of
        ``boxes.xyxy`` as a plain list and ``image`` is the loaded picture
        converted from BGR to RGB.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
        IndexError: If the model reports no boxes for the image.
    """
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError("Could not read image: {!r}".format(image_path))

    # Ultralytics interprets NumPy inputs as BGR, so the detector gets the
    # array exactly as OpenCV loaded it; only the returned copy is RGB.
    results = model(bgr_image, stream=True)  # generator of Results objects

    boxes = None
    for result in results:
        boxes = result.boxes  # Boxes object for bbox outputs

    detections = [] if boxes is None else boxes.xyxy.tolist()
    if not detections:
        raise IndexError("No objects detected in image: {!r}".format(image_path))

    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    return detections[0], image
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(height, width, 1)`` before colouring.
        ax: Object providing ``imshow`` (e.g. a matplotlib Axes).
        random_color: When true, use a random RGB colour instead of the
            fixed blue. The alpha channel is 0.6 in both cases.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30, 144, 255]) / 255
    rgba = np.append(rgb, 0.6)

    height, width = mask.shape[-2:]
    # Broadcasting (h, w, 1) * (1, 1, 4) paints every mask pixel with rgba.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Sequence or array of ``(x, y)`` rows, one per point.
        labels: Sequence or array with one label per row of ``coords``.
        ax: Object providing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: Value forwarded to ``scatter`` as ``s``.
    """
    # Coerce to arrays so plain Python lists work with boolean-mask indexing;
    # comparing a list to an int would yield a single bool, not a mask.
    coords = np.asarray(coords)
    labels = np.asarray(labels)

    # Positive points are drawn first, then negative ones.
    for label_value, color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=color, marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
    """Outline ``box`` on ``ax`` with a green, unfilled rectangle.

    Args:
        box: Indexable holding ``x0, y0, x1, y1`` at positions 0-3; width and
            height are derived as ``x1 - x0`` and ``y1 - y0``.
        ax: Object providing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top = box[0], box[1]
    width = box[2] - left
    height = box[3] - top
    outline = plt.Rectangle(
        (left, top),
        width,
        height,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
# --- Detection: YOLOv8 supplies the box prompt for SAM ----------------------
image_path = '/content/Cigaretee_i_Weapon_Pragati137.jpg'
model = YOLO('/content/best.pt')
yolov8_boxex, image = yolov8_detection(model, image_path)

# --- Segmentation: prompt SAM with the detected box --------------------------
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image)

input_box = np.array(yolov8_boxex)
masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

# --- Contours of the predicted mask -----------------------------------------
# multimask_output=False is requested above, so only the first mask is used.
# It is converted directly to uint8 for OpenCV (no torch round-trip needed).
binary_mask = masks[0].astype(np.uint8)
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
    raise ValueError("SAM produced an empty mask for {!r}; nothing to export".format(image_path))

# Largest contour by area stands in for the object outline.
largest_contour = max(contours, key=cv2.contourArea)
# Bounding box of that contour as returned by cv2.boundingRect (x, y, w, h).
bbox = [int(x) for x in cv2.boundingRect(largest_contour)]
# Flat [x0, y0, x1, y1, ...] list of the contour's pixel coordinates.
segmentation = largest_contour.flatten().tolist()

# Normalise against the array the contours were computed on, rather than
# re-opening the file with another library just to read its size.
height, width = image.shape[:2]

# --- Export 1: one "0 cx cy w h" line (normalised) per contour ---------------
box_lines = []
for contour in contours:
    x, y, w, h = cv2.boundingRect(contour)
    box_lines.append('0 {:.6f} {:.6f} {:.6f} {:.6f}\n'.format((x+w/2)/width, (y+h/2)/height, w/width, h/height))
with open('bounding_boxe_NEW_137.txt', 'w') as f:
    f.writelines(box_lines)

# --- Export 2: normalised bbox followed by the normalised polygon ------------
polygon = np.array(segmentation).reshape(-1, 2)  # (N, 2) pixel coordinates
mask_norm = polygon / np.array([width, height])
xmin, ymin = mask_norm.min(axis=0)
xmax, ymax = mask_norm.max(axis=0)
bbox_norm = np.array([xmin, ymin, xmax, ymax])
yolo = np.concatenate([bbox_norm, mask_norm.reshape(-1)])
with open('yolo_cigaNEW_137_MASKw.txt', 'w') as f:
    # Each value is followed by a single space, all on one line.
    f.write("".join("{:.6f} ".format(val) for val in yolo))

# --- Visual check ------------------------------------------------------------
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

# Print the bounding box and segmentation mask
print("Bounding box:", bbox)
print("Segmentation mask:", segmentation)