Skip to content

Commit

Permalink
Update plotting.py
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Jun 26, 2024
1 parent 1a0c821 commit 0f393ef
Showing 1 changed file with 55 additions and 29 deletions.
84 changes: 55 additions & 29 deletions ultralytics/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import warnings
from pathlib import Path
from typing import Union, Optional, List, Dict, Callable

import cv2
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -579,7 +580,8 @@ def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_cen

def display_analytics(self, im0, text, txt_color, bg_color, margin):
"""
Display the overall statistics for parking lots
Display the overall statistics for parking lots.
Args:
im0 (ndarray): inference image
text (dict): labels dictionary
Expand Down Expand Up @@ -661,7 +663,7 @@ def plot_angle_and_count_and_stage(
angle_text (str): angle value for workout monitoring
count_text (str): counts value for workout monitoring
stage_text (str): stage decision for workout monitoring
center_kpt (int): centroid pose index for workout monitoring
center_kpt (list): centroid pose index for workout monitoring
color (tuple): text background color for workout monitoring
txt_color (tuple): text foreground color for workout monitoring
"""
Expand Down Expand Up @@ -917,23 +919,49 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,

@threaded
def plot_images(
images,
batch_idx,
cls,
bboxes=np.zeros(0, dtype=np.float32),
confs=None,
masks=np.zeros(0, dtype=np.uint8),
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname="images.jpg",
names=None,
on_plot=None,
max_size=1920, # max image size
max_subplots=16,
save=True,
conf_thres=0.25,
):
"""Plot image grid with labels."""
images: Union[torch.Tensor, np.ndarray],
batch_idx: Union[torch.Tensor, np.ndarray],
cls: Union[torch.Tensor, np.ndarray],
bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
paths: Optional[List[str]] = None,
fname: str = "images.jpg",
names: Optional[Dict[int, str]] = None,
on_plot: Optional[Callable] = None,
max_size: int = 1920,
max_subplots: int = 16,
save: bool = True,
conf_thres: float = 0.25,
) -> Optional[np.ndarray]:
"""
Plot image grid with labels, bounding boxes, masks, and keypoints.
Args:
images: Batch of images to plot. Shape: (batch_size, channels, height, width).
batch_idx: Batch indices for each detection. Shape: (num_detections,).
cls: Class labels for each detection. Shape: (num_detections,).
bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
confs: Confidence scores for each detection. Shape: (num_detections,).
masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
kpts: Keypoints for each detection. Shape: (num_detections, 51).
paths: List of file paths for each image in the batch.
fname: Output filename for the plotted image grid.
names: Dictionary mapping class indices to class names.
on_plot: Optional callback function to be called after saving the plot.
max_size: Maximum size of the output image grid.
max_subplots: Maximum number of subplots in the image grid.
save: Whether to save the plotted image grid to a file.
conf_thres: Confidence threshold for displaying detections.
Returns:
np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.
Note:
This function supports both tensor and numpy array inputs. It will automatically
convert tensor inputs to numpy arrays for processing.
"""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
if isinstance(cls, torch.Tensor):
Expand Down Expand Up @@ -1166,6 +1194,12 @@ def plot_tune_results(csv_file="tune_results.csv"):
import pandas as pd # scope for faster 'import ultralytics'
from scipy.ndimage import gaussian_filter1d

def _save_one_file(file):

Check warning on line 1197 in ultralytics/utils/plotting.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/plotting.py#L1197

Added line #L1197 was not covered by tests
"""Save one matplotlib plot to 'file'."""
plt.savefig(file, dpi=200)
plt.close()
LOGGER.info(f"Saved {file}")

Check warning on line 1201 in ultralytics/utils/plotting.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/plotting.py#L1199-L1201

Added lines #L1199 - L1201 were not covered by tests

# Scatter plots for each hyperparameter
csv_file = Path(csv_file)
data = pd.read_csv(csv_file)
Expand All @@ -1186,11 +1220,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
if i % n != 0:
plt.yticks([])

file = csv_file.with_name("tune_scatter_plots.png") # filename
plt.savefig(file, dpi=200)
plt.close()
LOGGER.info(f"Saved {file}")
_save_one_file(csv_file.with_name("tune_scatter_plots.png"))

Check warning on line 1223 in ultralytics/utils/plotting.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/plotting.py#L1223

Added line #L1223 was not covered by tests

# Fitness vs iteration
x = range(1, len(fitness) + 1)
Expand All @@ -1202,11 +1232,7 @@ def plot_tune_results(csv_file="tune_results.csv"):
plt.ylabel("Fitness")
plt.grid(True)
plt.legend()

file = csv_file.with_name("tune_fitness.png") # filename
plt.savefig(file, dpi=200)
plt.close()
LOGGER.info(f"Saved {file}")
_save_one_file(csv_file.with_name("tune_fitness.png"))

Check warning on line 1235 in ultralytics/utils/plotting.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/plotting.py#L1235

Added line #L1235 was not covered by tests


def output_to_target(output, max_det=300):
Expand Down

0 comments on commit 0f393ef

Please sign in to comment.