Skip to content

Commit

Permalink
3D Detection Eval docstrings + typing fixes. (#40)
Browse files Browse the repository at this point in the history
* Add evaluation.

* Small fixes.

* Add unit tests + eval mask fixes.

* Update detection unit tests.

* Fix typing + update pyproject.

* Run autoflake.

* Add ROI pruning.

* Remove arg.

* Fix typo.

* Speed up argoverse maps.

* Speed up evaluation.

* Small fixes.

* Fix lint.

* Small lint fixes.

* Fix filtering.

* Small fixes.

* Fix enums.

* Remove unused lines.

* Mypy fixes.

* Fix click mypy error.

* Pytype fixes.

* Fix pytype.

* Remove pytype.

* Small typing fixes.

* Add unit tests.

* Fix typing.

* Remove click typing issue.

* Fix mypy.

* Detection eval speed up.

* Rewrite detection eval for major speedup.

* Typing fixes.

* Typing fixes.

* Switch from record arrays to numpy arrays.

* Temp changes.

* Improve readability.

* Add comments.

* Modularize evaluate.

* Additional speedups.

* Cleanup code.

* Additional speedup.

* Add roi pruning back.

* Add multiprocessing.

* Add verbosity.

* Mypy fixes.

* Update cuboid fields.

* Lint fixes.

* Fix map tutorial issues.

* Add test log.

* Revert strings.

* Remove outputs.

* Address missing detection edge cases.

* Address Jhony's comments.

* Update docstring.

* Clean docstrings.

* Change roi method.

* Clean up roi method.

* Update roi returns.

* Autoflake.

* Fix lint.

* Fix lint.

* Update detection limiting logic.

* Fix indexing.

* Fix tuple return.

* Update CI.

* Add ROI unit tests.

* Remove val identity.

* Fix import.

* Remove unused import.

* Update column names.

* Update eval.py

* Update detection eval.

* Update docstrings.

* Revert docstring.

* Fix formatting.

* Small fixes.

* Update ci.yml

* Update ci.yml

* Update ci.yml

* Update ci.yml

* Update ci.yml

* Update ci.yml

Co-authored-by: Benjamin Wilson <benjaminrwilson@noreply.github.com>
  • Loading branch information
benjaminrwilson and Benjamin Wilson authored Apr 27, 2022
1 parent 15d43e3 commit 0492dec
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
python_version:
["3.8", "3.9", "3.10"]
venv_backend:
["virtualenv", "mamba"] # Testing two different resolvers (pip, mamba).
["virtualenv"]
defaults:
run:
shell: bash -l {0}
Expand Down
31 changes: 19 additions & 12 deletions src/av2/evaluation/detection/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@
e.g. AP, ATE, ASE, AOE, CDS by default.
"""
import logging
from multiprocessing import get_context
from typing import Dict, Final, List, Optional, Tuple

import numpy as np
import pandas as pd
from joblib import Parallel, delayed

from av2.evaluation.detection.constants import NUM_DECIMALS, MetricNames, TruePositiveErrorNames
from av2.evaluation.detection.utils import (
Expand Down Expand Up @@ -95,7 +95,7 @@ def evaluate(
Each sweep is processed independently, computing assignment between detections and ground truth annotations.
Args:
dts: (N,15) Table of detections.
dts: (N,14) Table of detections.
gts: (M,15) Table of ground truth annotations.
cfg: Detection configuration.
n_jobs: Number of jobs running concurrently during evaluation.
Expand All @@ -106,23 +106,30 @@ def evaluate(
Raises:
RuntimeError: If accumulation fails.
ValueError: If ROI pruning is enabled but a dataset directory is not specified.
"""
if cfg.eval_only_roi_instances and cfg.dataset_dir is None:
raise ValueError(
"ROI pruning has been enabled, but the dataset directory has not be specified. "
"Please set `dataset_directory` to the split root, e.g. av2/sensor/val."
)

# Sort both the detections and annotations by lexicographic order for grouping.
dts = dts.sort_values(list(UUID_COLUMN_NAMES))
gts = gts.sort_values(list(UUID_COLUMN_NAMES))

dts_npy: NDArrayFloat = dts.loc[:, DTS_COLUMN_NAMES].to_numpy()
gts_npy: NDArrayFloat = gts.loc[:, GTS_COLUMN_NAMES].to_numpy()
dts_npy: NDArrayFloat = dts[list(DTS_COLUMN_NAMES)].to_numpy()
gts_npy: NDArrayFloat = gts[list(GTS_COLUMN_NAMES)].to_numpy()

dts_uuids: List[str] = dts.loc[:, UUID_COLUMN_NAMES].to_numpy().astype(str).tolist()
gts_uuids: List[str] = gts.loc[:, UUID_COLUMN_NAMES].to_numpy().astype(str).tolist()
dts_uuids: List[str] = dts[list(UUID_COLUMN_NAMES)].to_numpy().tolist()
gts_uuids: List[str] = gts[list(UUID_COLUMN_NAMES)].to_numpy().tolist()

# We merge the unique identifier -- the tuple of ("log_id", "timestamp_ns", "category")
# into a single string to optimize the subsequent grouping operation.
# `groupby_mapping` produces a mapping from the uuid to the group of detections / annotations
# which fall into that group.
uuid_to_dts = groupby([":".join(x) for x in dts_uuids], dts_npy)
uuid_to_gts = groupby([":".join(x) for x in gts_uuids], gts_npy)
uuid_to_dts = groupby([":".join(map(str, x)) for x in dts_uuids], dts_npy)
uuid_to_gts = groupby([":".join(map(str, x)) for x in gts_uuids], gts_npy)

log_id_to_avm: Optional[Dict[str, ArgoverseStaticMap]] = None
log_id_to_timestamped_poses: Optional[Dict[str, TimestampedCitySE3EgoPoses]] = None
Expand Down Expand Up @@ -154,9 +161,9 @@ def evaluate(
args_list.append(args)

logger.info("Starting evaluation ...")
outputs: Optional[List[Tuple[NDArrayFloat, NDArrayFloat]]] = Parallel(n_jobs=n_jobs, verbose=1)(
delayed(accumulate)(*args) for args in args_list
)
with get_context("spawn").Pool(processes=n_jobs) as p:
outputs: Optional[List[Tuple[NDArrayFloat, NDArrayFloat]]] = p.starmap(accumulate, args_list)

if outputs is None:
raise RuntimeError("Accumulation has failed! Please check the integrity of your detections and annotations.")
dts_list, gts_list = zip(*outputs)
Expand All @@ -182,7 +189,7 @@ def summarize_metrics(
"""Calculate and print the 3D object detection metrics.
Args:
dts: (N,15) Table of detections.
dts: (N,14) Table of detections.
gts: (M,15) Table of ground truth annotations.
cfg: Detection configuration.
Expand Down
2 changes: 0 additions & 2 deletions src/av2/evaluation/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class DetectionCfg:
max_range_m: Max distance (under a specific metric in meters) for a detection or ground truth cuboid to be
considered for evaluation.
num_recall_samples: Number of recall points to sample uniformly in [0, 1].
splits: Tuple of split names to evaluate.
tp_threshold_m: Center distance threshold for the true positive metrics (in meters).
"""

Expand All @@ -70,7 +69,6 @@ class DetectionCfg:
max_num_dts_per_category: int = 100
max_range_m: float = 200.0
num_recall_samples: int = 100
split: str = "val"
tp_threshold_m: float = 2.0

@property
Expand Down
4 changes: 2 additions & 2 deletions src/av2/geometry/utm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
WGS84: https://en.wikipedia.org/wiki/World_Geodetic_System
"""

from enum import unique, Enum
from enum import Enum, unique
from typing import Dict, Final, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -83,7 +83,7 @@ def convert_city_coords_to_utm(points_city: Union[NDArrayFloat, NDArrayInt], cit
latitude, longitude = CITY_ORIGIN_LATLONG_DICT[city_name]
# get (easting, northing) of origin
origin_utm = convert_gps_to_utm(latitude=latitude, longitude=longitude, city_name=city_name)

points_utm: NDArrayFloat = points_city.astype(float) + np.array(origin_utm, dtype=float)
return points_utm

Expand Down
2 changes: 1 addition & 1 deletion tests/geometry/test_utm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

import av2.geometry.utm as geo_utils
from av2.geometry.utm import CityName, CITY_ORIGIN_LATLONG_DICT
from av2.geometry.utm import CityName
from av2.utils.typing import NDArrayFloat


Expand Down

0 comments on commit 0492dec

Please sign in to comment.