Commit

added torchvision det, seg n pose examples to repo
Mikel Broström committed Sep 30, 2024
1 parent 74b324f commit a0846d7
Showing 3 changed files with 236 additions and 37 deletions.
34 changes: 16 additions & 18 deletions examples/torchvision_pose_boxmot.ipynb
@@ -71,30 +71,28 @@
" inds = tracks[:, 7].astype('int') # Get track indices as int\n",
"\n",
" # Use the indices to match tracks with keypoints\n",
" if len(keypoints) > 0:\n",
" keypoints = [keypoints[i] for i in inds if i < len(keypoints)] # Reorder keypoints to match the tracks\n",
" keypoints = [keypoints[i] for i in inds if i < len(keypoints)] # Reorder keypoints to match the tracks\n",
"\n",
" # Draw keypoints on the image\n",
" for i, kp in enumerate(keypoints):\n",
" # Get the color of the corresponding track\n",
" color = get_color(tracks[i, 4]) # tracks[i, 4] is the track_id\n",
" # Draw bounding boxes and keypoints in the same loop\n",
" for i, track in enumerate(tracks):\n",
" x1, y1, x2, y2, track_id, conf, cls = track[:7].astype('int')\n",
" color = get_color(track_id)\n",
"\n",
" # Draw bounding box with unique color\n",
" cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)\n",
"\n",
" # Add text with ID, confidence, and class\n",
" cv2.putText(im, f'ID: {track_id}, Conf: {conf:.2f}, Class: {cls}', \n",
" (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)\n",
"\n",
" # Draw keypoints for the corresponding track\n",
" if i < len(keypoints):\n",
" kp = keypoints[i]\n",
" for point in kp:\n",
" x, y, confidence = int(point[0]), int(point[1]), point[2]\n",
" if confidence > 0.5: # Only draw keypoints with confidence > 0.5\n",
" cv2.circle(im, (x, y), 3, color, -1) # Draw keypoints in the color of the corresponding track\n",
"\n",
" # Show the image (optional: draw bounding boxes and keypoints)\n",
" for track in tracks:\n",
" x1, y1, x2, y2, track_id, conf, cls = track[:7].astype('int')\n",
" color = get_color(track_id)\n",
"\n",
" # Draw bounding box with unique color\n",
" cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)\n",
"\n",
" # Add text with ID, confidence, and class\n",
" cv2.putText(im, f'ID: {track_id}, Conf: {conf:.2f}, Class: {cls}', \n",
" (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)\n",
"\n",
" # Display the image\n",
" cv2.imshow('Pose Tracking', im)\n",
"\n",
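The new hunk collapses the previous separate keypoint loop and box loop into a single pass per track. Read on its own, it amounts to roughly the sketch below, assuming tracks is the M x 8 array returned by tracker.update() (x1, y1, x2, y2, id, conf, cls, ind) and keypoints has already been reordered via the ind column so that keypoints[i] belongs to tracks[i]:

    # Single-pass drawing: box, label and keypoints share one color per track.
    for i, track in enumerate(tracks):
        x1, y1, x2, y2, track_id, conf, cls = track[:7].astype('int')
        color = get_color(track_id)

        # Bounding box and ID/confidence/class label
        cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)
        cv2.putText(im, f'ID: {track_id}, Conf: {conf:.2f}, Class: {cls}',
                    (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # Keypoints of the matching pose, drawn in the same track color
        if i < len(keypoints):
            for x, y, kp_conf in keypoints[i]:
                if kp_conf > 0.5:  # skip low-confidence joints
                    cv2.circle(im, (int(x), int(y)), 3, color, -1)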
36 changes: 17 additions & 19 deletions examples/torchvision_seg_boxmot.ipynb
@@ -67,38 +67,36 @@
" # Update tracker with detections and image\n",
" tracks = tracker.update(dets, im) # M x (x, y, x, y, id, conf, cls, ind)\n",
"\n",
" # Draw segmentation masks and bounding boxes\n",
" # Draw segmentation masks and bounding boxes in a single loop\n",
" if len(tracks) > 0:\n",
" inds = tracks[:, 7].astype('int') # Get track indices as int\n",
"\n",
" # Use the indices to match tracks with masks\n",
" if len(masks) > 0:\n",
" masks = [masks[i] for i in inds if i < len(masks)] # Reorder masks to match the tracks\n",
"\n",
" # Draw masks on the image\n",
" for track, mask in zip(tracks, masks):\n",
" track_id = int(track[4]) # Extract track ID\n",
" color = get_color(track_id) # Use unique color for each track\n",
" \n",
" # Iterate over tracks and corresponding masks to draw them together\n",
" for track, mask in zip(tracks, masks):\n",
" track_id = int(track[4]) # Extract track ID\n",
" color = get_color(track_id) # Use unique color for each track\n",
" \n",
" # Draw the segmentation mask on the image\n",
" if mask is not None:\n",
" # Binarize the mask\n",
" mask = (mask > 0.5).astype(np.uint8)\n",
" \n",
" # Blend mask color with the image\n",
" im[mask == 1] = im[mask == 1] * 0.5 + np.array(color) * 0.5\n",
"\n",
" # Draw bounding boxes and tracking info\n",
" for track in tracks:\n",
" x1, y1, x2, y2, track_id = track[:5].astype('int')\n",
" color = get_color(track_id)\n",
" \n",
" # Draw bounding box with unique color\n",
" cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)\n",
" \n",
" # Add text with ID, confidence, and class\n",
" conf = track[5]\n",
" cls = track[6]\n",
" cv2.putText(im, f'ID: {track_id}, Conf: {conf:.2f}, Class: {cls}', \n",
" (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)\n",
" # Draw the bounding box\n",
" x1, y1, x2, y2 = track[:4].astype('int')\n",
" cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)\n",
" \n",
" # Add text with ID, confidence, and class\n",
" conf = track[5]\n",
" cls = track[6]\n",
" cv2.putText(im, f'ID: {track_id}, Conf: {conf:.2f}, Class: {cls}', \n",
" (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)\n",
"\n",
" # Display the image\n",
" cv2.imshow('Segmentation Tracking', im)\n",
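The segmentation hunk makes the same move for masks and boxes. As a standalone sketch, assuming tracks is the M x 8 tracker output, masks holds one 2-D soft mask per track (already reordered via the ind column), and im is the BGR frame:

    # One pass per track: blend the mask, then draw the box and label.
    for track, mask in zip(tracks, masks):
        track_id = int(track[4])
        color = get_color(track_id)

        if mask is not None:
            mask = (mask > 0.5).astype(np.uint8)  # binarize the soft mask
            # 50/50 blend of the frame and the track color inside the mask
            im[mask == 1] = im[mask == 1] * 0.5 + np.array(color) * 0.5

        x1, y1, x2, y2 = track[:4].astype('int')
        cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)
        cv2.putText(im, f'ID: {track_id}, Conf: {track[5]:.2f}, Class: {int(track[6])}',
                    (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)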
203 changes: 203 additions & 0 deletions examples/torchvision_tiled_det_boxmot.ipynb
@@ -0,0 +1,203 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-09-30 23:12:25.295\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mboxmot.utils.torch_utils\u001b[0m:\u001b[36mselect_device\u001b[0m:\u001b[36m52\u001b[0m - \u001b[1mYolo Tracking v11.0.2 🚀 Python-3.11.5 torch-2.2.2CPU\u001b[0m\n",
"\u001b[32m2024-09-30 23:12:25.316\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mboxmot.appearance.reid_model_factory\u001b[0m:\u001b[36mload_pretrained_weights\u001b[0m:\u001b[36m183\u001b[0m - \u001b[32m\u001b[1mLoaded pretrained weights from osnet_x0_25_msmt17.pt\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Performing prediction on 60 slices.\n"
]
},
{
"ename": "TypeError",
"evalue": "TorchvisionDetectionModel.convert_original_predictions() got an unexpected keyword argument 'full_shape'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb Cell 1\u001b[0m line \u001b[0;36m9\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=89'>90</a>\u001b[0m \u001b[39mbreak\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=91'>92</a>\u001b[0m \u001b[39m# Get sliced predictions using SAHI's get_sliced_prediction\u001b[39;00m\n\u001b[0;32m---> <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=92'>93</a>\u001b[0m result \u001b[39m=\u001b[39m get_sliced_prediction(\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=93'>94</a>\u001b[0m frame,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=94'>95</a>\u001b[0m detection_model,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=95'>96</a>\u001b[0m slice_height\u001b[39m=\u001b[39;49m\u001b[39m256\u001b[39;49m,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=96'>97</a>\u001b[0m slice_width\u001b[39m=\u001b[39;49m\u001b[39m256\u001b[39;49m,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=97'>98</a>\u001b[0m overlap_height_ratio\u001b[39m=\u001b[39;49m\u001b[39m0.2\u001b[39;49m,\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=98'>99</a>\u001b[0m overlap_width_ratio\u001b[39m=\u001b[39;49m\u001b[39m0.2\u001b[39;49m\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=99'>100</a>\u001b[0m )\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=101'>102</a>\u001b[0m \u001b[39m# Extract detections from result\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/mikel.brostrom/boxmot/examples/torchvision_tiled_det_boxmot.ipynb#W0sZmlsZQ%3D%3D?line=102'>103</a>\u001b[0m num_predictions \u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(result\u001b[39m.\u001b[39mobject_prediction_list)\n",
"File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/boxmot-YDNZdsaB-py3.11/lib/python3.11/site-packages/sahi/predict.py:249\u001b[0m, in \u001b[0;36mget_sliced_prediction\u001b[0;34m(image, detection_model, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio, perform_standard_pred, postprocess_type, postprocess_match_metric, postprocess_match_threshold, postprocess_class_agnostic, verbose, merge_buffer_length, auto_slice_resolution, slice_export_prefix, slice_dir)\u001b[0m\n\u001b[1;32m 247\u001b[0m shift_amount_list\u001b[39m.\u001b[39mappend(slice_image_result\u001b[39m.\u001b[39mstarting_pixels[group_ind \u001b[39m*\u001b[39m num_batch \u001b[39m+\u001b[39m image_ind])\n\u001b[1;32m 248\u001b[0m \u001b[39m# perform batch prediction\u001b[39;00m\n\u001b[0;32m--> 249\u001b[0m prediction_result \u001b[39m=\u001b[39m get_prediction(\n\u001b[1;32m 250\u001b[0m image\u001b[39m=\u001b[39;49mimage_list[\u001b[39m0\u001b[39;49m],\n\u001b[1;32m 251\u001b[0m detection_model\u001b[39m=\u001b[39;49mdetection_model,\n\u001b[1;32m 252\u001b[0m shift_amount\u001b[39m=\u001b[39;49mshift_amount_list[\u001b[39m0\u001b[39;49m],\n\u001b[1;32m 253\u001b[0m full_shape\u001b[39m=\u001b[39;49m[\n\u001b[1;32m 254\u001b[0m slice_image_result\u001b[39m.\u001b[39;49moriginal_image_height,\n\u001b[1;32m 255\u001b[0m slice_image_result\u001b[39m.\u001b[39;49moriginal_image_width,\n\u001b[1;32m 256\u001b[0m ],\n\u001b[1;32m 257\u001b[0m )\n\u001b[1;32m 258\u001b[0m \u001b[39m# convert sliced predictions to full predictions\u001b[39;00m\n\u001b[1;32m 259\u001b[0m \u001b[39mfor\u001b[39;00m object_prediction \u001b[39min\u001b[39;00m prediction_result\u001b[39m.\u001b[39mobject_prediction_list:\n",
"File \u001b[0;32m~/Library/Caches/pypoetry/virtualenvs/boxmot-YDNZdsaB-py3.11/lib/python3.11/site-packages/sahi/predict.py:100\u001b[0m, in \u001b[0;36mget_prediction\u001b[0;34m(image, detection_model, shift_amount, full_shape, postprocess, verbose)\u001b[0m\n\u001b[1;32m 98\u001b[0m time_start \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mtime()\n\u001b[1;32m 99\u001b[0m \u001b[39m# works only with 1 batch\u001b[39;00m\n\u001b[0;32m--> 100\u001b[0m detection_model\u001b[39m.\u001b[39;49mconvert_original_predictions(\n\u001b[1;32m 101\u001b[0m shift_amount\u001b[39m=\u001b[39;49mshift_amount,\n\u001b[1;32m 102\u001b[0m full_shape\u001b[39m=\u001b[39;49mfull_shape,\n\u001b[1;32m 103\u001b[0m )\n\u001b[1;32m 104\u001b[0m object_prediction_list: List[ObjectPrediction] \u001b[39m=\u001b[39m detection_model\u001b[39m.\u001b[39mobject_prediction_list\n\u001b[1;32m 106\u001b[0m \u001b[39m# postprocess matching predictions\u001b[39;00m\n",
"\u001b[0;31mTypeError\u001b[0m: TorchvisionDetectionModel.convert_original_predictions() got an unexpected keyword argument 'full_shape'"
]
}
],
"source": [
"import torch\n",
"import torchvision\n",
"import cv2\n",
"import numpy as np\n",
"from pathlib import Path\n",
"from boxmot import BoTSORT\n",
"from sahi.models.base import DetectionModel\n",
"from sahi.predict import get_sliced_prediction\n",
"\n",
"# Define a custom detection model that is compatible with SAHI\n",
"class TorchvisionDetectionModel(DetectionModel):\n",
" def __init__(self, model, confidence_threshold=0.5, device='cpu'):\n",
" super().__init__(confidence_threshold=confidence_threshold, device=device)\n",
" self.model = model.to(device)\n",
" self.device = device\n",
" self.confidence_threshold = confidence_threshold\n",
" self.model.eval()\n",
"\n",
" def load_model(self):\n",
" # The model is already loaded during initialization, so we just set it\n",
" self.set_model(self.model)\n",
"\n",
" def set_model(self, model):\n",
" # Set the model directly\n",
" self.model = model\n",
"\n",
" def perform_inference(self, image):\n",
" # Convert the image to a tensor and move to the specified device\n",
" image_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(self.device)\n",
"\n",
" # Perform detection using the model\n",
" with torch.no_grad():\n",
" outputs = self.model(image_tensor)[0]\n",
"\n",
" return outputs\n",
"\n",
" def convert_original_predictions(self, outputs, original_image, shift_amount=None, full_shape=None):\n",
" # Convert the model output to the format expected by SAHI\n",
" detection_list = []\n",
" for i, score in enumerate(outputs['scores']):\n",
" if score >= self.confidence_threshold:\n",
" x1, y1, x2, y2 = outputs['boxes'][i].cpu().numpy() # Bounding box coordinates\n",
"\n",
" # Apply shift if shift_amount is provided\n",
" if shift_amount is not None:\n",
" x1 += shift_amount[1]\n",
" y1 += shift_amount[0]\n",
" x2 += shift_amount[1]\n",
" y2 += shift_amount[0]\n",
"\n",
" conf = score.item() # Confidence score\n",
" label = outputs['labels'][i].item() # Class label\n",
"\n",
" detection_list.append({\n",
" 'bbox': [x1, y1, x2, y2],\n",
" 'score': conf,\n",
" 'category_id': label\n",
" })\n",
"\n",
" return detection_list\n",
"\n",
"# Load a pre-trained Faster R-CNN model from torchvision\n",
"device = torch.device('cpu') # Use 'cuda' if you have a GPU\n",
"detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)\n",
"\n",
"# Wrap the torchvision model in the custom SAHI-compatible detection model\n",
"detection_model = TorchvisionDetectionModel(model=detector, device=device, confidence_threshold=0.5)\n",
"\n",
"# Initialize BoTSORT Tracker\n",
"tracker = BoTSORT(\n",
" reid_weights=Path('osnet_x0_25_msmt17.pt'), # Path to ReID model\n",
" device=device, # Use CPU for inference\n",
" half=False\n",
")\n",
"\n",
"# Open the video file (use 0 for webcam or provide a video file path)\n",
"vid = cv2.VideoCapture(0)\n",
"\n",
"# Function to generate a unique color for each track ID\n",
"def get_color(track_id):\n",
" np.random.seed(int(track_id))\n",
" return tuple(np.random.randint(0, 255, 3).tolist())\n",
"\n",
"while True:\n",
" # Capture frame-by-frame\n",
" ret, frame = vid.read()\n",
"\n",
" # If ret is False, it means we have reached the end of the video or there's an error\n",
" if not ret:\n",
" break\n",
"\n",
" # Get sliced predictions using SAHI's get_sliced_prediction\n",
" result = get_sliced_prediction(\n",
" frame,\n",
" detection_model,\n",
" slice_height=256,\n",
" slice_width=256,\n",
" overlap_height_ratio=0.2,\n",
" overlap_width_ratio=0.2\n",
" )\n",
"\n",
" # Extract detections from result\n",
" num_predictions = len(result.object_prediction_list)\n",
" dets = np.zeros([num_predictions, 6], dtype=np.float32)\n",
" for ind, object_prediction in enumerate(result.object_prediction_list):\n",
" bbox = object_prediction.bbox.to_xyxy()\n",
" dets[ind, :4] = np.array(bbox, dtype=np.float32)\n",
" dets[ind, 4] = object_prediction.score.value\n",
" dets[ind, 5] = object_prediction.category.id\n",
"\n",
" # Update the tracker with the detections\n",
" tracks = tracker.update(dets, frame) # (M x (x1, y1, x2, y2, id, conf, cls, ind))\n",
"\n",
" # Draw the tracking results on the image\n",
" for track in tracks:\n",
" x1, y1, x2, y2, track_id, conf, cls = track[:7].astype('int')\n",
" color = get_color(track_id)\n",
"\n",
" # Draw bounding box with unique color\n",
" cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n",
"\n",
" # Add text with ID, confidence, and class\n",
" cv2.putText(frame, f'ID: {track_id}, Conf: {conf:.2f}, Class: {cls}', \n",
" (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)\n",
"\n",
" # Display the frame with tracking results\n",
" cv2.imshow('BoXMOT + Torchvision with Tiled Inference', frame)\n",
"\n",
" # Simulate wait for key press to continue, press 'q' to exit\n",
" key = cv2.waitKey(1) & 0xFF\n",
" if key == ord(' ') or key == ord('q'):\n",
" break\n",
"\n",
"# Release the video capture and close all OpenCV windows\n",
"vid.release()\n",
"cv2.destroyAllWindows()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "boxmot-YDNZdsaB-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
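The committed output of the new notebook ends in a TypeError: the traceback shows sahi.predict.get_prediction calling detection_model.convert_original_predictions(shift_amount=..., full_shape=...) with no other arguments and then reading detection_model.object_prediction_list, while the override above takes outputs and original_image positionally and returns a plain list. A hedged sketch of a signature that would at least match that call site follows; the _original_predictions attribute used to hand results between the two methods is an assumption for illustration, not something this commit defines:

    # Hypothetical adjustment, not part of this commit: accept exactly the
    # keyword arguments SAHI passes, and read the raw outputs from state set
    # during perform_inference instead of taking them as parameters.
    class TorchvisionDetectionModel(DetectionModel):
        # ... __init__, load_model and set_model as in the cell above ...

        def perform_inference(self, image):
            image_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                # Stash the raw torchvision outputs for the conversion step.
                self._original_predictions = self.model(image_tensor)[0]

        def convert_original_predictions(self, shift_amount=None, full_shape=None):
            outputs = self._original_predictions
            # Build the SAHI-facing predictions from `outputs` here, shifting
            # boxes by `shift_amount` and clipping to `full_shape`, so that
            # self.object_prediction_list is populated for get_prediction.

Per the type annotation visible in the traceback (List[ObjectPrediction]), SAHI expects that list to hold sahi ObjectPrediction instances rather than the plain dicts the current conversion returns.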
