Performance issue for own mAP implementation #677

Closed · tkupek opened this issue Dec 10, 2021 · 7 comments · Fixed by #742
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

@tkupek (Contributor) commented Dec 10, 2021

🐛 Bug

The new mAP implementation that replaced the pycocotools implementation seems to have a performance issue. In my measurements it is 10-15x slower.

These are some performance measurements I did on a CPU and a single GPU (same machine):

New implementation

CPU
Running val metric on 640 samples
Total time: 4.70652174949646
Time per sample 0.007353940233588219

Single GPU
Running val metric on 640 samples
Total time: 33.491135358810425
Time per sample 0.05232989899814129

Running val metric on 64 samples
Total time: 4.064690351486206
Time per sample 0.06351078674197197

Running val metric on 1280 samples
Total time: 61.18194818496704
Time per sample 0.0477983970195055

Previous implementation (pycocotools)

CPU
Running val metric on 640 samples
Total time: 0.386138916015625
Time per sample 0.000603342056274414

Single GPU
Running val metric on 640 samples
Total time: 2.527284622192383
Time per sample 0.003950447216629982

To Reproduce

Steps to reproduce the behavior:

Run the mAP implementation as in the example below.

Code sample

Find the measurement code with BoringModel here:

import time

import torch
import torch.distributed
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import Dataset, DataLoader

from torchmetrics.detection.map import MAP

BATCH_SIZE = 32

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


train = RandomDataset(32, 40 * BATCH_SIZE)
train = DataLoader(train, batch_size=BATCH_SIZE)

val = RandomDataset(32, 40 * BATCH_SIZE)
val = DataLoader(val, batch_size=BATCH_SIZE)

# mockups for MAP compatible data
mock_preds = [
    dict(
        boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
        scores=torch.Tensor([0.536]),
        labels=torch.IntTensor([0]),
    )
]
mock_target = [
    dict(
        boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
        labels=torch.IntTensor([0]),
    )
]


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_map = MAP(class_metrics=True, dist_sync_on_step=True)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)

        # ignore real outputs and add mockup preds to metric
        preds = []
        target = []

        for n in range(batch.size(0)):
            x = mock_preds[0]
            preds.append({
                'boxes': x['boxes'].to(self.device),
                'labels': x['labels'].to(self.device),
                'scores': x['scores'].to(self.device)
            })

            x = mock_target[0]
            target.append({
                'boxes': x['boxes'].to(self.device),
                'labels': x['labels'].to(self.device)
            })

        self.val_map.update(preds=preds, target=target)
        return {"x": loss}

    def on_validation_epoch_start(self) -> None:
        self.val_map.reset()

    def on_validation_epoch_end(self) -> None:
        if self.trainer.global_step != 0:
            print(
                f"\nRunning val metric on {len(self.val_map.groundtruth_boxes)} samples"
            )
            start = time.time()
            result = self.val_map.compute()  # GPUs get stuck here
            end = time.time()
            print(f"Total time: {end-start}")
            print(f"Time per sample {(time.time() - start) / len(self.val_map.groundtruth_boxes)}")

    def configure_optimizers(self):
        return [torch.optim.SGD(self.layer.parameters(), lr=0.1)]


model = BoringModel()
trainer = Trainer(
    max_epochs=10,
    strategy='ddp',
    gpus=1
)

trainer.fit(model, train, val)

Expected behavior

Performance should at least be on par with the pycocotools implementation.

Best case, it should be faster, especially on GPU.

Environment

  • PyTorch Version (e.g., 1.0): 1.0
  • OS (e.g., Linux): Windows 10
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.9
  • CUDA/cuDNN version: 11.3
  • GPU models and configuration: GTX 1060
  • Any other relevant information:

Additional context

@tkupek added the bug / fix (Something isn't working) and help wanted (Extra attention is needed) labels on Dec 10, 2021
@Borda added this to the v0.7 milestone on Dec 10, 2021
@Borda added the enhancement (New feature or request) label and removed the bug / fix (Something isn't working) label on Dec 10, 2021
@OlofHarrysson (Contributor) commented Dec 14, 2021

I've come across a profiling tool that measures time spent per line. I'll share a modified version of your code, @tkupek:

import time

import torch
import torch.distributed
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import Dataset, DataLoader

from torchmetrics.detection.map import MAP


BATCH_SIZE = 32


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


train = RandomDataset(32, 40 * BATCH_SIZE)
train = DataLoader(train, batch_size=BATCH_SIZE)

val = RandomDataset(32, 40 * BATCH_SIZE)
val = DataLoader(val, batch_size=BATCH_SIZE)

# mockups for MAP compatible data
mock_preds = [
    dict(
        boxes=torch.Tensor([[258.0, 41.0, 606.0, 285.0]]),
        scores=torch.Tensor([0.536]),
        labels=torch.IntTensor([0]),
    )
]
mock_target = [
    dict(
        boxes=torch.Tensor([[214.0, 41.0, 562.0, 285.0]]),
        labels=torch.IntTensor([0]),
    )
]


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_map = MAP(class_metrics=True, dist_sync_on_step=True)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)

        # ignore real outputs and add mockup preds to metric
        preds = []
        target = []

        for n in range(batch.size(0)):
            x = mock_preds[0]
            preds.append(
                {
                    "boxes": x["boxes"].to(self.device),
                    "labels": x["labels"].to(self.device),
                    "scores": x["scores"].to(self.device),
                }
            )

            x = mock_target[0]
            target.append({"boxes": x["boxes"].to(self.device), "labels": x["labels"].to(self.device)})

        self.val_map.update(preds=preds, target=target)
        return {"x": loss}

    def on_validation_epoch_start(self) -> None:
        self.val_map.reset()

    def on_validation_epoch_end(self) -> None:
        if self.trainer.global_step != 0:
            print(f"\nRunning val metric on {len(self.val_map.groundtruth_boxes)} samples")
            start = time.time()
            result = self.val_map.compute()  # GPUs get stuck here
            end = time.time()
            print(f"Total time: {end-start}")
            print(f"Time per sample {(time.time() - start) / len(self.val_map.groundtruth_boxes)}")

    def configure_optimizers(self):
        return [torch.optim.SGD(self.layer.parameters(), lr=0.1)]


# pip install line_profiler
# Line profiler is NOT accurate for CUDA code.
import contextlib
from typing import List, Callable

import line_profiler


class WrappedLineProfiler(line_profiler.LineProfiler):
    """Measures time for executing code in the specified profiling_functions.
    More info: https://github.com/pyutils/line_profiler

    Call the print_stats() method after profiling to get results"""

    def __init__(self, profiling_functions: List[Callable]):
        super().__init__(*profiling_functions)

    @contextlib.contextmanager
    def __call__(self):
        self.enable()  # Start measuring time
        yield  # profiling_functions are expected to run here
        self.disable()  # Stop measuring time


model = BoringModel()
trainer = Trainer(max_epochs=1, strategy="ddp", gpus=None)

profiling_functions = [
    MAP.compute,
    MAP._calculate,
    MAP._evaluate_image,
    MAP._find_best_gt_match,
]
profiler = WrappedLineProfiler(profiling_functions)

with profiler():
    trainer.fit(model, train, val)
profiler.print_stats()

This outputs:

Timer unit: 1e-06 s

Total time: 6.58653 s
File: /Users/olof/miniforge3_x86_64/envs/metric/lib/python3.8/site-packages/torchmetrics/detection/map.py
Function: _evaluate_image at line 342

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   342                                               def _evaluate_image(
   343                                                   self, id: int, class_id: int, area_range: Tuple[int, int], max_det: int, ious: dict
   344                                               ) -> Optional[dict]:
   345                                                   """Perform evaluation for single class and image.
   346
   347                                                   Args:
   348                                                       id:
   349                                                           Image Id, equivalent to the index of supplied samples.
   350                                                       class_id:
   351                                                           Class Id of the supplied ground truth and detection labels.
   352                                                       area_range:
   353                                                           List of lower and upper bounding box area threshold.
   354                                                       max_det:
   355                                                           Maximum number of evaluated detection bounding boxes.
   356                                                       ious:
   357                                                           IoU results for image and class.
   358                                                   """
   359     10240      11579.0      1.1      0.2          gt = self.groundtruth_boxes[id]
   360     10240      10637.0      1.0      0.2          det = self.detection_boxes[id]
   361     10240      68231.0      6.7      1.0          gt_label_mask = self.groundtruth_labels[id] == class_id
   362     10240      55527.0      5.4      0.8          det_label_mask = self.detection_labels[id] == class_id
   363     10240      48872.0      4.8      0.7          if len(det_label_mask) == 0 or len(det_label_mask) == 0:
   364                                                       return None
   365     10240      98945.0      9.7      1.5          gt = gt[gt_label_mask]
   366     10240      81480.0      8.0      1.2          det = det[det_label_mask]
   367     10240      30070.0      2.9      0.5          if len(gt) == 0 and len(det) == 0:
   368                                                       return None
   369
   370     10240     201581.0     19.7      3.1          areas = box_area(gt)
   371     10240     144051.0     14.1      2.2          ignore_area = (areas < area_range[0]) | (areas > area_range[1])
   372
   373                                                   # sort dt highest score first, sort gt ignore last
   374     10240      87979.0      8.6      1.3          ignore_area_sorted, gtind = torch.sort(ignore_area)
   375     10240      66937.0      6.5      1.0          gt = gt[gtind]
   376     10240      11904.0      1.2      0.2          scores = self.detection_scores[id]
   377     10240      81570.0      8.0      1.2          scores_filtered = scores[det_label_mask]
   378     10240      85413.0      8.3      1.3          scores_sorted, dtind = torch.sort(scores_filtered, descending=True)
   379     10240      62781.0      6.1      1.0          det = det[dtind]
   380     10240     101055.0      9.9      1.5          if len(det) > max_det:
   381                                                       det = det[:max_det]
   382                                                   # load computed ious
   383     10240     101549.0      9.9      1.5          ious = ious[id, class_id][:, gtind] if len(ious[id, class_id]) > 0 else ious[id, class_id]
   384
   385     10240      29728.0      2.9      0.5          nb_iou_thrs = len(self.iou_thresholds)
   386     10240      25793.0      2.5      0.4          nb_gt = len(gt)
   387     10240      25232.0      2.5      0.4          nb_det = len(det)
   388     10240      36515.0      3.6      0.6          gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool)
   389     10240      28745.0      2.8      0.4          det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool)
   390     10240      10208.0      1.0      0.2          gt_ignore = ignore_area_sorted
   391     10240      28023.0      2.7      0.4          det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool)
   392     10240      13158.0      1.3      0.2          if torch.numel(ious) > 0:
   393    112640     215248.0      1.9      3.3              for idx_iou, t in enumerate(self.iou_thresholds):
   394    204800     232445.0      1.1      3.5                  for idx_det in range(nb_det):
   395    102400    2750873.0     26.9     41.8                      m = MAP._find_best_gt_match(t, nb_gt, gt_matches, idx_iou, gt_ignore, ious, idx_det)
   396    102400     104717.0      1.0      1.6                      if m is not -1:
   397     61440     327357.0      5.3      5.0                          det_ignore[idx_iou, idx_det] = gt_ignore[m]
   398     61440     313011.0      5.1      4.8                          det_matches[idx_iou, idx_det] = True
   399     61440     303296.0      4.9      4.6                          gt_matches[idx_iou, m] = True
   400                                                   # set unmatched detections outside of area range to ignore
   401     10240     198475.0     19.4      3.0          det_areas = box_area(det)
   402     10240     143230.0     14.0      2.2          det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
   403     10240      35723.0      3.5      0.5          ar = det_ignore_area.reshape((1, nb_det))
   404     20480      57541.0      2.8      0.9          det_ignore = torch.logical_or(
   405     10240     297028.0     29.0      4.5              det_ignore, torch.logical_and(det_matches == 0, torch.repeat_interleave(ar, nb_iou_thrs, 0))
   406                                                   )
   407     10240      11387.0      1.1      0.2          return {
   408     10240      10184.0      1.0      0.2              "dtMatches": det_matches,
   409     10240       9665.0      0.9      0.1              "gtMatches": gt_matches,
   410     10240       9605.0      0.9      0.1              "dtScores": scores_sorted,
   411     10240       9600.0      0.9      0.1              "gtIgnore": gt_ignore,
   412     10240       9587.0      0.9      0.1              "dtIgnore": det_ignore,
   413                                                   }

Total time: 2.12782 s
File: /Users/olof/miniforge3_x86_64/envs/metric/lib/python3.8/site-packages/torchmetrics/detection/map.py
Function: _find_best_gt_match at line 415

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   415                                               @staticmethod
   416                                               def _find_best_gt_match(
   417                                                   thr: int, nb_gt: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int
   418                                               ) -> int:
   419                                                   """Return id of best ground truth match with current detection.
   420
   421                                                   Args:
   422                                                       thr:
   423                                                           Current threshold value.
   424                                                       nb_gt:
   425                                                           Number of ground truth elements.
   426                                                       gt_matches:
   427                                                           Tensor showing if a ground truth matches for threshold ``t`` exists.
   428                                                       idx_iou:
   429                                                           Id of threshold ``t``.
   430                                                       gt_ignore:
   431                                                           Tensor showing if ground truth should be ignored.
   432                                                       ious:
   433                                                           IoUs for all combinations of detection and ground truth.
   434                                                       idx_det:
   435                                                           Id of current detection.
   436                                                   """
   437                                                   # information about best match so far (m=-1 -> unmatched)
   438    102400     718494.0      7.0     33.8          iou = min([thr, 1 - 1e-10])
   439    102400      50142.0      0.5      2.4          match_id = -1
   440    204800     108606.0      0.5      5.1          for idx_gt in range(nb_gt):
   441                                                       # if this gt already matched, and not a crowd, continue
   442    102400     397433.0      3.9     18.7              if gt_matches[idx_iou, idx_gt]:
   443                                                           continue
   444                                                       # if dt matched to reg gt, and on ignore gt, stop
   445    102400      47277.0      0.5      2.2              if match_id > -1 and not gt_ignore[match_id] and gt_ignore[idx_gt]:
   446                                                           break
   447                                                       # continue to next gt unless better match made
   448    102400     552561.0      5.4     26.0              if ious[idx_det, idx_gt] < iou:
   449     40960      18751.0      0.5      0.9                  continue
   450                                                       # if match successful and best so far, store appropriately
   451     61440     168800.0      2.7      7.9              iou = ious[idx_det, idx_gt]
   452     61440      26907.0      0.4      1.3              match_id = idx_gt
   453    102400      38849.0      0.4      1.8          return match_id

Total time: 9.7553 s
File: /Users/olof/miniforge3_x86_64/envs/metric/lib/python3.8/site-packages/torchmetrics/detection/map.py
Function: _calculate at line 498

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   498                                               def _calculate(self, class_ids: List) -> Tuple[Dict, MAPMetricResults, MARMetricResults]:
   499                                                   """Calculate the precision, recall and scores for all supplied label classes to calculate mAP/mAR.
   500
   501                                                   Args:
   502                                                       class_ids:
   503                                                           List of label class Ids.
   504                                                   """
   505         2        435.0    217.5      0.0          img_ids = torch.arange(len(self.groundtruth_boxes), dtype=torch.int).tolist()
   506         2         19.0      9.5      0.0          max_detections = self.max_detection_thresholds[-1]
   507         2          4.0      2.0      0.0          area_ranges = self.bbox_area_ranges.values()
   508
   509         4     387382.0  96845.5      4.0          ious = {
   510         2          2.0      1.0      0.0              (id, class_id): self._compute_iou(id, class_id, max_detections) for id in img_ids for class_id in class_ids
   511                                                   }
   512
   513         4    8015838.0 2003959.5     82.2          eval_imgs = [
   514                                                       self._evaluate_image(img_id, class_id, area, max_detections, ious)
   515         2          2.0      1.0      0.0              for class_id in class_ids
   516                                                       for area in area_ranges
   517                                                       for img_id in img_ids
   518                                                   ]
   519
   520         2          9.0      4.5      0.0          nb_iou_thrs = len(self.iou_thresholds)
   521         2          9.0      4.5      0.0          nb_rec_thrs = len(self.rec_thresholds)
   522         2          4.0      2.0      0.0          nb_classes = len(class_ids)
   523         2          4.0      2.0      0.0          nb_bbox_areas = len(self.bbox_area_ranges)
   524         2          9.0      4.5      0.0          nb_max_det_thrs = len(self.max_detection_thresholds)
   525         2          4.0      2.0      0.0          nb_imgs = len(img_ids)
   526         2       1588.0    794.0      0.0          precision = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
   527         2         17.0      8.5      0.0          recall = -torch.ones((nb_iou_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
   528         2         35.0     17.5      0.0          scores = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
   529
   530                                                   # retrieve E at each category, area range, and max number of detections
   531         4          7.0      1.8      0.0          for idx_cls in range(nb_classes):
   532        10         18.0      1.8      0.0              for idx_bbox_area in range(nb_bbox_areas):
   533        32        152.0      4.8      0.0                  for idx_max_det_thrs, max_det in enumerate(self.max_detection_thresholds):
   534        48    1346035.0  28042.4     13.8                      recall, precision, scores = MAP.__calculate_recall_precision_scores(
   535        24         34.0      1.4      0.0                          recall,
   536        24         36.0      1.5      0.0                          precision,
   537        24         38.0      1.6      0.0                          scores,
   538        24         32.0      1.3      0.0                          idx_cls=idx_cls,
   539        24         38.0      1.6      0.0                          idx_bbox_area=idx_bbox_area,
   540        24         35.0      1.5      0.0                          idx_max_det_thrs=idx_max_det_thrs,
   541        24         38.0      1.6      0.0                          eval_imgs=eval_imgs,
   542        24         46.0      1.9      0.0                          rec_thresholds=self.rec_thresholds,
   543        24         38.0      1.6      0.0                          max_det=max_det,
   544        24         34.0      1.4      0.0                          nb_imgs=nb_imgs,
   545        24         38.0      1.6      0.0                          nb_bbox_areas=nb_bbox_areas,
   546                                                               )
   547
   548         2          9.0      4.5      0.0          results = {
   549         2          5.0      2.5      0.0              "dimensions": [nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs],
   550         2          2.0      1.0      0.0              "precision": precision,
   551         2          4.0      2.0      0.0              "recall": recall,
   552         2          2.0      1.0      0.0              "scores": scores,
   553                                                   }
   554
   555         2          5.0      2.5      0.0          map_metrics = MAPMetricResults()
   556         2        832.0    416.0      0.0          map_metrics.map = self._summarize(results, True)
   557         2          9.0      4.5      0.0          last_max_det_thr = self.max_detection_thresholds[-1]
   558         2        609.0    304.5      0.0          map_metrics.map_50 = self._summarize(results, True, iou_threshold=0.5, max_dets=last_max_det_thr)
   559         2        221.0    110.5      0.0          map_metrics.map_75 = self._summarize(results, True, iou_threshold=0.75, max_dets=last_max_det_thr)
   560         2        354.0    177.0      0.0          map_metrics.map_small = self._summarize(results, True, area_range="small", max_dets=last_max_det_thr)
   561         2        129.0     64.5      0.0          map_metrics.map_medium = self._summarize(results, True, area_range="medium", max_dets=last_max_det_thr)
   562         2        213.0    106.5      0.0          map_metrics.map_large = self._summarize(results, True, area_range="large", max_dets=last_max_det_thr)
   563
   564         2          6.0      3.0      0.0          mar_metrics = MARMetricResults()
   565         8         27.0      3.4      0.0          for max_det in self.max_detection_thresholds:
   566         6        505.0     84.2      0.0              mar_metrics[f"mar_{max_det}"] = self._summarize(results, False, max_dets=max_det)
   567         2        115.0     57.5      0.0          mar_metrics.mar_small = self._summarize(results, False, area_range="small", max_dets=last_max_det_thr)
   568         2        112.0     56.0      0.0          mar_metrics.mar_medium = self._summarize(results, False, area_range="medium", max_dets=last_max_det_thr)
   569         2        155.0     77.5      0.0          mar_metrics.mar_large = self._summarize(results, False, area_range="large", max_dets=last_max_det_thr)
   570
   571         2          4.0      2.0      0.0          return results, map_metrics, mar_metrics

Total time: 9.78097 s
File: /Users/olof/miniforge3_x86_64/envs/metric/lib/python3.8/site-packages/torchmetrics/detection/map.py
Function: compute at line 641

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   641                                               def compute(self) -> dict:
   642                                                   """Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)` scores.
   643
   644                                                   Note:
   645                                                       `map` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]
   646
   647                                                       Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well.
   648                                                       The default properties are also accessible via fields and will raise an ``AttributeError`` if not available.
   649
   650                                                   Returns:
   651                                                       dict containing
   652
   653                                                       - map: ``torch.Tensor``
   654                                                       - map_50: ``torch.Tensor``
   655                                                       - map_75: ``torch.Tensor``
   656                                                       - map_small: ``torch.Tensor``
   657                                                       - map_medium: ``torch.Tensor``
   658                                                       - map_large: ``torch.Tensor``
   659                                                       - mar_1: ``torch.Tensor``
   660                                                       - mar_10: ``torch.Tensor``
   661                                                       - mar_100: ``torch.Tensor``
   662                                                       - mar_small: ``torch.Tensor``
   663                                                       - mar_medium: ``torch.Tensor``
   664                                                       - mar_large: ``torch.Tensor``
   665                                                       - map_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
   666                                                       - mar_100_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
   667                                                   """
   668         1    4960085.0 4960085.0     50.7          overall, map, mar = self._calculate(self._get_classes())
   669
   670         1         43.0     43.0      0.0          map_per_class_values: Tensor = Tensor([-1])
   671         1          4.0      4.0      0.0          mar_max_dets_per_class_values: Tensor = Tensor([-1])
   672
   673                                                   # if class mode is enabled, evaluate metrics per class
   674         1          4.0      4.0      0.0          if self.class_metrics:
   675         1          1.0      1.0      0.0              map_per_class_list = []
   676         1          1.0      1.0      0.0              mar_max_dets_per_class_list = []
   677
   678         2        503.0    251.5      0.0              for class_id in self._get_classes():
   679         1    4820217.0 4820217.0     49.3                  _, cls_map, cls_mar = self._calculate([class_id])
   680         1         15.0     15.0      0.0                  map_per_class_list.append(cls_map.map)
   681         1         37.0     37.0      0.0                  mar_max_dets_per_class_list.append(cls_mar[f"mar_{self.max_detection_thresholds[-1]}"])
   682
   683         1         34.0     34.0      0.0              map_per_class_values = Tensor(map_per_class_list)
   684         1          5.0      5.0      0.0              mar_max_dets_per_class_values = Tensor(mar_max_dets_per_class_list)
   685
   686         1          1.0      1.0      0.0          metrics = COCOMetricResults()
   687         1          2.0      2.0      0.0          metrics.update(map)
   688         1          2.0      2.0      0.0          metrics.update(mar)
   689         1          8.0      8.0      0.0          metrics.map_per_class = map_per_class_values
   690         1          6.0      6.0      0.0          metrics[f"mar_{self.max_detection_thresholds[-1]}_per_class"] = mar_max_dets_per_class_values
   691         1          1.0      1.0      0.0          return metrics

I used this information to find that line 438, iou = min([thr, 1 - 1e-10]), takes a significant amount of time. It's quicker to convert the threshold to a float beforehand, i.e. iou = min([thr.item(), 1 - 1e-10]).
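
For context, here is a minimal micro-benchmark sketch of that difference (the loop count and threshold value are made up for illustration, not taken from the issue):

import time
import torch

thr = torch.tensor(0.5)  # 0-d tensor, like the entries of iou_thresholds

start = time.time()
for _ in range(100_000):
    min([thr, 1 - 1e-10])  # comparison dispatches through tensor ops
print(f"tensor : {time.time() - start:.3f}s")

start = time.time()
for _ in range(100_000):
    min([thr.item(), 1 - 1e-10])  # plain Python float comparison
print(f".item(): {time.time() - start:.3f}s")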

I'm running an M1 MacBook, so I haven't tried anything on the GPU.

@tkupek (Contributor, Author) commented Dec 15, 2021

@OlofHarrysson this is a good insight! I can test this today on a GPU.
@twsl FYI

@OlofHarrysson (Contributor) commented:

@tkupek Please do :)

Note that profiling CUDA calls with this profiler can be inaccurate. Operations performed after CUDA calls, e.g. moving a tensor from GPU to CPU with tensor.cpu(), will often incorrectly have time attributed to the .cpu() line while Python is actually blocked waiting for the CUDA kernel to finish.

You basically have to sprinkle in torch.cuda.synchronize() calls to wait for the CUDA code to finish if you want the correct time spent on each line. I'd advise using this profiler to get a rough estimate of which code is slowest, and then use another tool to measure any problematic parts.
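
For illustration, a minimal sketch of synchronized GPU timing (assumes a CUDA device is available; the matmul is just a stand-in workload):

import time
import torch

x = torch.randn(4096, 4096, device="cuda")

torch.cuda.synchronize()  # drain pending GPU work before starting the clock
start = time.time()
y = x @ x  # the kernel launch returns immediately (asynchronous)
torch.cuda.synchronize()  # block until the kernel finishes so the time is attributed here
print(f"matmul: {time.time() - start:.4f}s")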

@tkupek (Contributor, Author) commented Dec 15, 2021

I ran the profiler on a GPU.

Function: _evaluate_image at line 378

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   378                                               def _evaluate_image(
   379                                                   self, id: int, class_id: int, area_range: Tuple[int, int], max_det: int, ious: dict
   380                                               ) -> Optional[dict]:
   381                                                   """Perform evaluation for single class and image.
   382                                           
   383                                                   Args:
   384                                                       id:
   385                                                           Image Id, equivalent to the index of supplied samples.
   386                                                       class_id:
   387                                                           Class Id of the supplied ground truth and detection labels.
   388                                                       area_range:
   389                                                           List of lower and upper bounding box area threshold.
   390                                                       max_det:
   391                                                           Maximum number of evaluated detection bounding boxes.
   392                                                       ious:
   393                                                           IoU results for image and class.
   394                                                   """
   395     10240     274326.0     26.8      0.1          gt = self.groundtruth_boxes[id]
   396     10240     215817.0     21.1      0.1          det = self.detection_boxes[id]
   397     10240    2654947.0    259.3      1.2          gt_label_mask = self.groundtruth_labels[id] == class_id
   398     10240    1708472.0    166.8      0.8          det_label_mask = self.detection_labels[id] == class_id
   399     10240    2305276.0    225.1      1.1          if len(det_label_mask) == 0 or len(det_label_mask) == 0:
   400                                                       return None
   401     10240    4707926.0    459.8      2.2          gt = gt[gt_label_mask]
   402     10240    2215483.0    216.4      1.0          det = det[det_label_mask]
   403     10240     954598.0     93.2      0.4          if len(gt) == 0 and len(det) == 0:
   404                                                       return None
   405                                           
   406     10240    8138888.0    794.8      3.8          areas = box_area(gt)
   407     10240    5400976.0    527.4      2.5          ignore_area = (areas < area_range[0]) | (areas > area_range[1])
   408                                           
   409                                                   # sort dt highest score first, sort gt ignore last
   410     10240    5884955.0    574.7      2.8          ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8))
   411                                                   # Convert to uint8 temporarily and back to bool, because "Sort currently does not support bool dtype on CUDA"
   412     10240    1289357.0    125.9      0.6          ignore_area_sorted = ignore_area_sorted.to(torch.bool)
   413     10240    2528658.0    246.9      1.2          gt = gt[gtind]
   414     10240     314471.0     30.7      0.1          scores = self.detection_scores[id]
   415     10240    2465126.0    240.7      1.2          scores_filtered = scores[det_label_mask]
   416     10240    3253308.0    317.7      1.5          scores_sorted, dtind = torch.sort(scores_filtered, descending=True)
   417     10240    1882477.0    183.8      0.9          det = det[dtind]
   418     10240    4426422.0    432.3      2.1          if len(det) > max_det:
   419                                                       det = det[:max_det]
   420                                                   # load computed ious
   421     10240    3931997.0    384.0      1.8          ious = ious[id, class_id][:, gtind] if len(ious[id, class_id]) > 0 else ious[id, class_id]
   422                                           
   423     10240     850791.0     83.1      0.4          nb_iou_thrs = len(self.iou_thresholds)
   424     10240     575727.0     56.2      0.3          nb_gt = len(gt)
   425     10240     502157.0     49.0      0.2          nb_det = len(det)
   426     10240    2208119.0    215.6      1.0          gt_matches = torch.zeros((nb_iou_thrs, nb_gt), dtype=torch.bool, device=self.device)
   427     10240     908242.0     88.7      0.4          det_matches = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
   428     10240     151761.0     14.8      0.1          gt_ignore = ignore_area_sorted
   429     10240     788307.0     77.0      0.4          det_ignore = torch.zeros((nb_iou_thrs, nb_det), dtype=torch.bool, device=self.device)
   430     10240     439961.0     43.0      0.2          if torch.numel(ious) > 0:
   431    112640    6776639.0     60.2      3.2              for idx_iou, t in enumerate(self.iou_thresholds):
   432    204800    4494709.0     21.9      2.1                  for idx_det in range(nb_det):
   433    102400   79464542.0    776.0     37.3                      m = MAP._find_best_gt_match(t, nb_gt, gt_matches, idx_iou, gt_ignore, ious, idx_det)
   434    102400    1642236.0     16.0      0.8                      if m != -1:
   435     61440   10818807.0    176.1      5.1                          det_ignore[idx_iou, idx_det] = gt_ignore[m]
   436     61440    9021856.0    146.8      4.2                          det_matches[idx_iou, idx_det] = True
   437     61440    8329511.0    135.6      3.9                          gt_matches[idx_iou, m] = True
   438                                                   # set unmatched detections outside of area range to ignore
   439     10240    8939618.0    873.0      4.2          det_areas = box_area(det)
   440     10240    6049290.0    590.8      2.8          det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
   441     10240    1796892.0    175.5      0.8          ar = det_ignore_area.reshape((1, nb_det))
   442     20480    1865023.0     91.1      0.9          det_ignore = torch.logical_or(
   443     10240   12177737.0   1189.2      5.7              det_ignore, torch.logical_and(det_matches == 0, torch.repeat_interleave(ar, nb_iou_thrs, 0))
   444                                                   )
   445     10240     251404.0     24.6      0.1          return {
   446     10240     171059.0     16.7      0.1              "dtMatches": det_matches,
   447     10240     122665.0     12.0      0.1              "gtMatches": gt_matches,
   448     10240     122553.0     12.0      0.1              "dtScores": scores_sorted,
   449     10240     119976.0     11.7      0.1              "gtIgnore": gt_ignore,
   450     10240     118389.0     11.6      0.1              "dtIgnore": det_ignore,
   451                                                   }

Total time: 6.49406 s
File: C:\Users\tkupek\devel\torchmetrics\torchmetrics\detection\map.py
Function: _find_best_gt_match at line 453

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   453                                               @staticmethod
   454                                               def _find_best_gt_match(
   455                                                   thr: int, nb_gt: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int
   456                                               ) -> int:
   457                                                   """Return id of best ground truth match with current detection.
   458                                           
   459                                                   Args:
   460                                                       thr:
   461                                                           Current threshold value.
   462                                                       nb_gt:
   463                                                           Number of ground truth elements.
   464                                                       gt_matches:
   465                                                           Tensor showing if a ground truth matches for threshold ``t`` exists.
   466                                                       idx_iou:
   467                                                           Id of threshold ``t``.
   468                                                       gt_ignore:
   469                                                           Tensor showing if ground truth should be ignored.
   470                                                       ious:
   471                                                           IoUs for all combinations of detection and ground truth.
   472                                                       idx_det:
   473                                                           Id of current detection.
   474                                                   """
   475                                                   # information about best match so far (m=-1 -> unmatched)
   476    102400    6374074.0     62.2      9.8          iou = min([thr.item(), 1 - 1e-10])
   477    102400    1041030.0     10.2      1.6          match_id = -1
   478    204800    2648199.0     12.9      4.1          for idx_gt in range(nb_gt):
   479                                                       # if this gt already matched, and not a crowd, continue
   480    102400   13770490.0    134.5     21.2              if gt_matches[idx_iou, idx_gt]:
   481                                                           continue
   482                                                       # if dt matched to reg gt, and on ignore gt, stop
   483    102400    1269101.0     12.4      2.0              if match_id > -1 and not gt_ignore[match_id] and gt_ignore[idx_gt]:
   484                                                           break
   485                                                       # continue to next gt unless better match made
   486    102400   31630426.0    308.9     48.7              if ious[idx_det, idx_gt] < iou:
   487     40960     592879.0     14.5      0.9                  continue
   488                                                       # if match successful and best so far, store appropriately
   489     61440    6024284.0     98.1      9.3              iou = ious[idx_det, idx_gt]
   490     61440     664011.0     10.8      1.0              match_id = idx_gt
   491    102400     926110.0      9.0      1.4          return match_id

Total time: 28.7968 s
File: C:\Users\tkupek\devel\torchmetrics\torchmetrics\detection\map.py
Function: _calculate at line 536

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   536                                               def _calculate(self, class_ids: List) -> Tuple[Dict, MAPMetricResults, MARMetricResults]:
   537                                                   """Calculate the precision, recall and scores for all supplied label classes to calculate mAP/mAR.
   538                                           
   539                                                   Args:
   540                                                       class_ids:
   541                                                           List of label class Ids.
   542                                                   """
   543         2       1688.0    844.0      0.0          img_ids = torch.arange(len(self.groundtruth_boxes), dtype=torch.int).tolist()
   544         2        377.0    188.5      0.0          max_detections = self.max_detection_thresholds[-1]
   545         2         70.0     35.0      0.0          area_ranges = self.bbox_area_ranges.values()
   546                                           
   547         4   11070493.0 2767623.2      3.8          ious = {
   548         2         29.0     14.5      0.0              (id, class_id): self._compute_iou(id, class_id, max_detections) for id in img_ids for class_id in class_ids
   549                                                   }
   550                                           
   551         4  231073205.0 57768301.2     80.2          eval_imgs = [
   552                                                       self._evaluate_image(img_id, class_id, area, max_detections, ious)
   553         2         58.0     29.0      0.0              for class_id in class_ids
   554                                                       for area in area_ranges
   555                                                       for img_id in img_ids
   556                                                   ]
   557                                           
   558         2        554.0    277.0      0.0          nb_iou_thrs = len(self.iou_thresholds)
   559         2        171.0     85.5      0.0          nb_rec_thrs = len(self.rec_thresholds)
   560         2         49.0     24.5      0.0          nb_classes = len(class_ids)
   561         2         77.0     38.5      0.0          nb_bbox_areas = len(self.bbox_area_ranges)
   562         2        117.0     58.5      0.0          nb_max_det_thrs = len(self.max_detection_thresholds)
   563         2         41.0     20.5      0.0          nb_imgs = len(img_ids)
   564         2       3187.0   1593.5      0.0          precision = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
   565         2        961.0    480.5      0.0          recall = -torch.ones((nb_iou_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
   566         2        870.0    435.0      0.0          scores = -torch.ones((nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs))
   567                                           
   568                                                   # move tensors if necessary
   569         2        890.0    445.0      0.0          self.max_detection_thresholds = self.max_detection_thresholds.to(self.device)
   570         2        322.0    161.0      0.0          self.rec_thresholds = self.rec_thresholds.to(self.device)
   571                                           
   572                                                   # retrieve E at each category, area range, and max number of detections
   573         4        140.0     35.0      0.0          for idx_cls in range(nb_classes):
   574        10        336.0     33.6      0.0              for idx_bbox_area in range(nb_bbox_areas):
   575        32       6336.0    198.0      0.0                  for idx_max_det_thrs, max_det in enumerate(self.max_detection_thresholds):
   576        48   45721735.0 952536.1     15.9                      recall, precision, scores = MAP.__calculate_recall_precision_scores(
   577        24        558.0     23.2      0.0                          recall,
   578        24        517.0     21.5      0.0                          precision,
   579        24        490.0     20.4      0.0                          scores,
   580        24        493.0     20.5      0.0                          idx_cls=idx_cls,
   581        24        496.0     20.7      0.0                          idx_bbox_area=idx_bbox_area,
   582        24        501.0     20.9      0.0                          idx_max_det_thrs=idx_max_det_thrs,
   583        24        517.0     21.5      0.0                          eval_imgs=eval_imgs,
   584        24        887.0     37.0      0.0                          rec_thresholds=self.rec_thresholds,
   585        24        539.0     22.5      0.0                          max_det=max_det,
   586        24        544.0     22.7      0.0                          nb_imgs=nb_imgs,
   587        24        963.0     40.1      0.0                          nb_bbox_areas=nb_bbox_areas,
   588                                                               )
   589                                           
   590         2         61.0     30.5      0.0          results = {
   591         2         52.0     26.0      0.0              "dimensions": [nb_iou_thrs, nb_rec_thrs, nb_classes, nb_bbox_areas, nb_max_det_thrs],
   592         2         37.0     18.5      0.0              "precision": precision,
   593         2         36.0     18.0      0.0              "recall": recall,
   594         2         36.0     18.0      0.0              "scores": scores,
   595                                                   }
   596                                           
   597         2         79.0     39.5      0.0          map_metrics = MAPMetricResults()
   598         2      12382.0   6191.0      0.0          map_metrics.map = self._summarize(results, True)
   599         2        392.0    196.0      0.0          last_max_det_thr = self.max_detection_thresholds[-1]
   600         2       8417.0   4208.5      0.0          map_metrics.map_50 = self._summarize(results, True, iou_threshold=0.5, max_dets=last_max_det_thr)
   601         2       6567.0   3283.5      0.0          map_metrics.map_75 = self._summarize(results, True, iou_threshold=0.75, max_dets=last_max_det_thr)
   602         2       4357.0   2178.5      0.0          map_metrics.map_small = self._summarize(results, True, area_range="small", max_dets=last_max_det_thr)
   603         2       4009.0   2004.5      0.0          map_metrics.map_medium = self._summarize(results, True, area_range="medium", max_dets=last_max_det_thr)
   604         2      13828.0   6914.0      0.0          map_metrics.map_large = self._summarize(results, True, area_range="large", max_dets=last_max_det_thr)
   605                                           
   606         2         97.0     48.5      0.0          mar_metrics = MARMetricResults()
   607         8        873.0    109.1      0.0          for max_det in self.max_detection_thresholds:
   608         6      17195.0   2865.8      0.0              mar_metrics[f"mar_{max_det}"] = self._summarize(results, False, max_dets=max_det)
   609         2       2991.0   1495.5      0.0          mar_metrics.mar_small = self._summarize(results, False, area_range="small", max_dets=last_max_det_thr)
   610         2       3278.0   1639.0      0.0          mar_metrics.mar_medium = self._summarize(results, False, area_range="medium", max_dets=last_max_det_thr)
   611         2       4739.0   2369.5      0.0          mar_metrics.mar_large = self._summarize(results, False, area_range="large", max_dets=last_max_det_thr)
   612                                           
   613         2         47.0     23.5      0.0          return results, map_metrics, mar_metrics

Total time: 28.8363 s
File: C:\Users\tkupek\devel\torchmetrics\torchmetrics\detection\map.py
Function: compute at line 683

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   683                                               def compute(self) -> dict:
   684                                                   """Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)` scores.
   685                                           
   686                                                   Note:
   687                                                       `map` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]
   688                                           
   689                                                       Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well.
   690                                                       The default properties are also accessible via fields and will raise an ``AttributeError`` if not available.
   691                                           
   692                                                   Returns:
   693                                                       dict containing
   694                                           
   695                                                       - map: ``torch.Tensor``
   696                                                       - map_50: ``torch.Tensor``
   697                                                       - map_75: ``torch.Tensor``
   698                                                       - map_small: ``torch.Tensor``
   699                                                       - map_medium: ``torch.Tensor``
   700                                                       - map_large: ``torch.Tensor``
   701                                                       - mar_1: ``torch.Tensor``
   702                                                       - mar_10: ``torch.Tensor``
   703                                                       - mar_100: ``torch.Tensor``
   704                                                       - mar_small: ``torch.Tensor``
   705                                                       - mar_medium: ``torch.Tensor``
   706                                                       - mar_large: ``torch.Tensor``
   707                                                       - map_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
   708                                                       - mar_100_per_class: ``torch.Tensor`` (-1 if class metrics are disabled)
   709                                                   """
   710         1  116630958.0 116630958.0     40.4          overall, map, mar = self._calculate(self._get_classes())
   711                                           
   712         1        696.0    696.0      0.0          map_per_class_values: Tensor = Tensor([-1])
   713         1         59.0     59.0      0.0          mar_max_dets_per_class_values: Tensor = Tensor([-1])
   714                                           
   715                                                   # if class mode is enabled, evaluate metrics per class
   716         1         43.0     43.0      0.0          if self.class_metrics:
   717         1          5.0      5.0      0.0              map_per_class_list = []
   718         1          5.0      5.0      0.0              mar_max_dets_per_class_list = []
   719                                           
   720         2       8240.0   4120.0      0.0              for class_id in self._get_classes():
   721         1  171721016.0 171721016.0     59.6                  _, cls_map, cls_mar = self._calculate([class_id])
   722         1        235.0    235.0      0.0                  map_per_class_list.append(cls_map.map)
   723         1        757.0    757.0      0.0                  mar_max_dets_per_class_list.append(cls_mar[f"mar_{self.max_detection_thresholds[-1]}"])
   724                                           
   725         1        584.0    584.0      0.0              map_per_class_values = Tensor(map_per_class_list)
   726         1        119.0    119.0      0.0              mar_max_dets_per_class_values = Tensor(mar_max_dets_per_class_list)
   727                                           
   728         1         26.0     26.0      0.0          metrics = COCOMetricResults()
   729         1         68.0     68.0      0.0          metrics.update(map)
   730         1         42.0     42.0      0.0          metrics.update(mar)
   731         1         43.0     43.0      0.0          metrics.map_per_class = map_per_class_values
   732         1        181.0    181.0      0.0          metrics[f"mar_{self.max_detection_thresholds[-1]}_per_class"] = mar_max_dets_per_class_values
   733         1         10.0     10.0      0.0          return metrics

Your suggestion helped a bit, but the overall effect is small:

102400   18327534.0    179.0     35.5          iou = min([thr, 1 - 1e-10])

102400    6374074.0     62.2      9.8          iou = min([thr.item(), 1 - 1e-10])

@OlofHarrysson (Contributor) commented Dec 17, 2021

OK. Looking at the time spent in different parts of the code, there's a section that can be reworked to speed up the calculation by almost 2x when class_metrics=True.

eval_imgs = [
    self._evaluate_image(img_id, class_id, area, max_detections, ious)
    for class_id in class_ids
    for area in area_ranges
    for img_id in img_ids
]

That code takes a lot of time and is executed twice via calls in the compute() method. The results from self._evaluate_image(...) can be cached to avoid this; see the sketch after the snippets below.

# Code is called here
overall, map, mar = self._calculate(self._get_classes())

# And also here
for class_id in self._get_classes():
    _, cls_map, cls_mar = self._calculate([class_id])
    map_per_class_list.append(cls_map.map)
    mar_max_dets_per_class_list.append(cls_mar[f"mar_{self.max_detection_thresholds[-1]}"])

Seems like an easy fix.
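
A rough sketch of the caching idea (the wrapper name and key layout here are hypothetical, not necessarily how the eventual fix was implemented; the cache would also need to be cleared when the metric state changes, e.g. in reset()):

from typing import Callable, Dict, Optional, Tuple

def make_cached_evaluator(evaluate_image: Callable[..., Optional[dict]]) -> Callable[..., Optional[dict]]:
    # Memoize results per (img_id, class_id, area) so the overall pass and the
    # later per-class passes in compute() reuse the same evaluations.
    cache: Dict[Tuple, Optional[dict]] = {}

    def cached(img_id: int, class_id: int, area: Tuple, *args) -> Optional[dict]:
        key = (img_id, class_id, area)  # assumes area is a hashable tuple
        if key not in cache:
            cache[key] = evaluate_image(img_id, class_id, area, *args)
        return cache[key]

    return cached

In _calculate, the eval_imgs list comprehension would then call the cached wrapper instead of self._evaluate_image directly.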

But measuring time on this data can be a bit misleading, as there is always exactly one prediction and one GT box per image; results would differ on other data. It would be good to measure the time on, e.g., COCO eval data, which contains 5000 images, with predictions from a standard model.

At any rate, I think the code could be reworked to run faster. _find_best_gt_match is a method that stands out to me since it uses loops. I reworked it to operate on tensors, but it didn't speed the code up; it would probably pay off when there is a larger number of GT boxes per image.

@staticmethod
def _find_best_gt_match(
    thr: int, nb_gt: int, gt_matches: Tensor, idx_iou: float, gt_ignore: Tensor, ious: Tensor, idx_det: int
) -> int:
    previously_matched = gt_matches[idx_iou]
    # Remove previously matched or ignored gts
    remove_mask = previously_matched | gt_ignore
    gt_ious = ious[idx_det] * ~remove_mask
    match_idx = gt_ious.argmax().item()
    if gt_ious[match_idx] > thr:
        return match_idx
    return -1
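
For what it's worth, a tiny usage sketch of how this tensor version behaves (the tensors are made up, and it assumes the function is defined at module level rather than as a staticmethod):

import torch

ious = torch.tensor([[0.6, 0.8, 0.4]])            # 1 detection x 3 ground truths
gt_matches = torch.zeros(1, 3, dtype=torch.bool)  # nothing matched yet
gt_ignore = torch.tensor([False, True, False])    # gt 1 is ignored

m = _find_best_gt_match(torch.tensor(0.5), 3, gt_matches, 0, gt_ignore, ious, 0)
# gt 1 has the highest IoU (0.8) but is masked out, so gt 0 (IoU 0.6 > 0.5) wins
assert m == 0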

@tkupek (Contributor, Author) commented Dec 20, 2021

I have a real-world detection model and data on hand where I can test your improvements. I'll hopefully find some time in the next 1-2 weeks.

@tkupek (Contributor, Author) commented Dec 27, 2021

I performed the tests on a real-world dataset, and the CUDA issue is confirmed:

CPU

Running metric on 10 samples on cpu
Total time: 6.439759016036987
Time per sample 0.6439759016036988

CUDA

Running metric on 10 samples on cuda:0
Total time: 120.02612257003784
Time per sample 12.002612257003785

I will now test the performance impact of your suggestions and try to get insights from the profiler.

@Borda modified the milestones: v0.7 → v0.8 on Jan 6, 2022
@twsl mentioned this issue on Jan 11, 2022
@Borda added this to the v0.8 milestone on May 5, 2022