From ca10f5154d62d07a7d9afbbd4c50a34970f1fbdc Mon Sep 17 00:00:00 2001 From: Myeong-geun Shin <100759658+furiosamg@users.noreply.github.com> Date: Wed, 29 May 2024 14:48:54 +0900 Subject: [PATCH] Add YOLOv5 Python Postprocesses (#185) * Parameterize anchor, class for yolo * Don't set class attributes * Add YOLOv5 Python postprocesses * Remove num_classes * Return ObjectDetectionResult for YOLO Python PP * Update YOLOv5m/l accuracy * Use new yolo postproc * Update test expectations * Update yolov5m rust accuracy * Update YOLOv5l Rust accuracy * Bump to dev2 * Add accuracy tests, class-aware nms * Use 0.10.0.dev0 furiosa-native-postprocess * Lint, update test oracles * Update yolov5l accuracies * Update accuracy targets * Use torchvision.ops.nms * Update mobilenet rust pp accuracy * Run black, use fnp 0.10.0 release version * Update docs for YOLOv5 --- .gitignore | 7 + docs/models/yolov5l.md | 14 +- docs/models/yolov5m.md | 12 +- furiosa/models/__init__.py | 1 + furiosa/models/types.py | 8 +- furiosa/models/vision/yolov5/core.py | 43 ++-- furiosa/models/vision/yolov5/large.py | 13 +- furiosa/models/vision/yolov5/medium.py | 13 +- furiosa/models/vision/yolov5/postprocess.py | 228 ++++++++++++++++++++ pyproject.toml | 4 +- tests/__init__.py | 0 tests/accuracy/__init__.py | 0 tests/accuracy/test_yolov5l.py | 133 ++++++++++++ tests/accuracy/test_yolov5m.py | 135 ++++++++++++ tests/bench/test_ssd_mobilenet.py | 2 +- tests/bench/test_yolov5l.py | 57 ++++- tests/bench/test_yolov5m.py | 60 +++++- tests/unit/test_batched_yolov5l.py | 2 - tests/unit/test_batched_yolov5m.py | 2 - 19 files changed, 680 insertions(+), 54 deletions(-) create mode 100644 furiosa/models/vision/yolov5/postprocess.py create mode 100644 tests/__init__.py create mode 100644 tests/accuracy/__init__.py create mode 100644 tests/accuracy/test_yolov5l.py create mode 100644 tests/accuracy/test_yolov5m.py diff --git a/.gitignore b/.gitignore index 2a44636c..8292ea6e 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,10 @@ cython_debug/ # IDE .idea .vscode + +# Benchmarks +.benchmarks/ + +# Test data (including symbolic links to it) +/tests/data +/tests/data/ \ No newline at end of file diff --git a/docs/models/yolov5l.md b/docs/models/yolov5l.md index 86b941be..1e4dd63f 100644 --- a/docs/models/yolov5l.md +++ b/docs/models/yolov5l.md @@ -33,7 +33,7 @@ The input is a 3-channel image of 640, 640 (height, width). * Optimal Batch Size (minimum: 1): <= 2 ## Outputs -The outputs are 3 `numpy.float32` tensors in various shapes as the following. +The outputs are 3 `numpy.float32` tensors in various shapes as the following. You can refer to `postprocess()` function to learn how to decode boxes, classes, and confidence scores. | Tensor | Shape | Data Type | Data Type | Description | @@ -41,20 +41,20 @@ You can refer to `postprocess()` function to learn how to decode boxes, classes, | 0 | (1, 45, 80, 80) | float32 | NCHW | | | 1 | (1, 45, 40, 40) | float32 | NCHW | | | 2 | (1, 45, 20, 20) | float32 | NCHW | | - + ## Pre/Postprocessing `furiosa.models.vision.YOLOv5l` class provides `preprocess` and `postprocess` methods. -`preprocess` method converts input images to input tensors, and `postprocess` method converts -model output tensors to a list of bounding boxes, scores and labels. +`preprocess` method converts input images to input tensors, and `postprocess` method converts +model output tensors to a list of bounding boxes, scores and labels. You can find examples at [YOLOv5l Usage](#YOLOv5l_Usage). 
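A minimal end-to-end sketch of this preprocess/run/postprocess flow, assuming the synchronous runtime API that this patch's bench tests use (`furiosa.runtime.sync.create_runner`); the image path is a placeholder:

```python
import cv2

from furiosa.models.vision import YOLOv5l
from furiosa.runtime.sync import create_runner

model = YOLOv5l()  # Rust postprocessor by default; pass postprocessor_type="Python" to switch
images = [cv2.imread("input.jpg")]  # placeholder path; any 3-channel BGR image works

inputs, contexts = model.preprocess(images)
with create_runner(model.model_source()) as runner:
    outputs = runner.run(inputs)

for det in model.postprocess(outputs, contexts, conf_thres=0.25, iou_thres=0.45)[0]:
    print(det.label, det.score, det.boundingbox)
```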
- + ### `furiosa.models.vision.YOLOv5l.preprocess` ::: furiosa.models.vision.yolov5.core.YOLOv5PreProcessor.__call__ options: show_source: false - + ### `furiosa.models.vision.YOLOv5l.postprocess` -::: furiosa.models.vision.yolov5.core.YOLOv5PostProcessor.__call__ +::: furiosa.models.vision.yolov5.core.YOLOv5PythonPostProcessor.__call__ options: show_source: false diff --git a/docs/models/yolov5m.md b/docs/models/yolov5m.md index c4a6e500..85389a30 100644 --- a/docs/models/yolov5m.md +++ b/docs/models/yolov5m.md @@ -33,7 +33,7 @@ The input is a 3-channel image of 640, 640 (height, width). * Optimal Batch Size (minimum: 1): <= 4 ## Outputs -The outputs are 3 `numpy.float32` tensors in various shapes as the following. +The outputs are 3 `numpy.float32` tensors in various shapes as the following. You can refer to `postprocess()` function to learn how to decode boxes, classes, and confidence scores. | Tensor | Shape | Data Type | Data Type | Description | @@ -44,17 +44,17 @@ You can refer to `postprocess()` function to learn how to decode boxes, classes, ## Pre/Postprocessing `furiosa.models.vision.YOLOv5m` class provides `preprocess` and `postprocess` methods. -`preprocess` method converts input images to input tensors, and `postprocess` method converts -model output tensors to a list of bounding boxes, scores and labels. +`preprocess` method converts input images to input tensors, and `postprocess` method converts +model output tensors to a list of bounding boxes, scores and labels. You can find examples at [YOLOv5m Usage](#YOLOv5m_Usage). - + ### `furiosa.models.vision.YOLOv5m.preprocess` ::: furiosa.models.vision.yolov5.core.YOLOv5PreProcessor.__call__ options: show_source: false - + ### `furiosa.models.vision.YOLOv5m.postprocess` -::: furiosa.models.vision.yolov5.core.YOLOv5PostProcessor.__call__ +::: furiosa.models.vision.yolov5.core.YOLOv5PythonPostProcessor.__call__ options: show_source: false diff --git a/furiosa/models/__init__.py b/furiosa/models/__init__.py index 15c7e384..dc279bc5 100644 --- a/furiosa/models/__init__.py +++ b/furiosa/models/__init__.py @@ -1,4 +1,5 @@ """Furiosa Models""" + from . import errors, vision __version__ = "0.10.0.dev0" diff --git a/furiosa/models/types.py b/furiosa/models/types.py index 04f9bada..2eb38e29 100644 --- a/furiosa/models/types.py +++ b/furiosa/models/types.py @@ -47,14 +47,14 @@ class Format(str, Enum): class PreProcessor(ABC): @abstractmethod - def __call__(self, inputs: Any) -> Tuple[Sequence[npt.ArrayLike], Sequence[Context]]: - ... + def __call__(self, inputs: Any) -> Tuple[Sequence[npt.ArrayLike], Sequence[Context]]: ... class PostProcessor(ABC): @abstractmethod - def __call__(self, model_outputs: Sequence[npt.ArrayLike], contexts: Sequence[Context]) -> Any: - ... + def __call__( + self, model_outputs: Sequence[npt.ArrayLike], contexts: Sequence[Context] + ) -> Any: ... 
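These ABCs fix the processor contract: a preprocessor returns `(input_tensors, contexts)` and a postprocessor consumes `(model_outputs, contexts)`. A minimal sketch of a custom postprocessor against this interface; the class and its reduction logic are hypothetical, and `contexts` is simplified to dicts, as the YOLOv5 processors in this patch do:

```python
from typing import Any, Dict, List, Sequence

import numpy as np
import numpy.typing as npt

from furiosa.models.types import PostProcessor


class MaxScorePostProcessor(PostProcessor):
    """Hypothetical postprocessor: reduce each output tensor to its max score."""

    def __call__(
        self, model_outputs: Sequence[npt.ArrayLike], contexts: Sequence[Dict[str, Any]]
    ) -> List[float]:
        # One scalar per output tensor; a real postprocessor would use `contexts`
        # (e.g. letterbox scale/padding) to map results back to image space.
        return [float(np.max(np.asarray(output))) for output in model_outputs]
```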
class RustPostProcessor(PostProcessor): diff --git a/furiosa/models/vision/yolov5/core.py b/furiosa/models/vision/yolov5/core.py index 364ed375..f288b76b 100644 --- a/furiosa/models/vision/yolov5/core.py +++ b/furiosa/models/vision/yolov5/core.py @@ -16,6 +16,7 @@ ) from ...vision.postprocess import LtrbBoundingBox, ObjectDetectionResult from ..preprocess import read_image_opencv_if_needed +from .postprocess import YOLOv5PythonPostProcessor _INPUT_SIZE = (640, 640) _STRIDES = [8, 16, 32] @@ -153,17 +154,22 @@ def __call__( return np.stack(batched_image, axis=0), batched_proc_params -class YOLOv5PostProcessor(RustPostProcessor): +def sigmoid(x: np.ndarray) -> np.ndarray: + # pylint: disable=invalid-name + return 1 / (1 + np.exp(-x)) + + +class YOLOv5NativePostProcessor(RustPostProcessor): def __init__(self, anchors: npt.ArrayLike, class_names: Sequence[str]): """ - native (RustProcessor): A native postprocessor. It has several information to decode: (xyxy, - confidence threshold, anchor_grid, stride, number of classes). - class_names (Sequence[str]): A list of class names. + Args: + anchors (npt.ArrayLike): Anchor sizes, shaped (num_layers, anchors_per_layer, 2). + class_names (Sequence[str]): A list of class names. """ self.anchors = anchors self.class_names = class_names self.anchor_per_layer_count = anchors.shape[1] - self.native = native.yolov5.RustPostProcessor(anchors, _STRIDES) + self.native = native.yolo.RustPostProcessor(anchors, _STRIDES) def __call__( self, @@ -171,6 +177,7 @@ def __call__( contexts: Sequence[Dict[str, Any]], conf_thres: float = 0.25, iou_thres: float = 0.45, + with_sigmoid: bool = False, ) -> List[List[ObjectDetectionResult]]: """Convert the outputs of this model to a list of bounding boxes, scores and labels @@ -184,6 +191,8 @@ def __call__( and height. conf_thres: Confidence score threshold. The default to 0.25 iou_thres: IoU threshold value for the NMS processing. The default to 0.45. + with_sigmoid: Whether to apply the sigmoid function to the model outputs. Defaults + to False. Returns: Detected Bounding Box and its score and label represented as `ObjectDetectionResult`.
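For reference, a sketch of the anchor layout this constructor expects and of the per-cell decode both postprocessors perform. The anchor values below are the upstream YOLOv5 defaults; the `_ANCHORS` constants actually bundled with YOLOv5l/YOLOv5m are not shown in this patch and may differ. The decode assumes sigmoid-activated outputs (the `with_sigmoid=True` path):

```python
import numpy as np

# Upstream YOLOv5 default anchors, shaped (num_layers, anchors_per_layer, 2).
anchors = np.array(
    [
        [[10, 13], [16, 30], [33, 23]],       # P3, stride 8
        [[30, 61], [62, 45], [59, 119]],      # P4, stride 16
        [[116, 90], [156, 198], [373, 326]],  # P5, stride 32
    ],
    dtype=np.float32,
)
assert anchors.shape == (3, 3, 2)  # anchors.shape[1] is the per-layer anchor count

# Per-cell decode, mirroring the postprocessors:
#   xy = (sigmoid(t_xy) * 2 + grid) * stride, where grid holds (cell_index - 0.5)
#   wh = (sigmoid(t_wh) * 2) ** 2 * anchor
sigmoid = lambda t: 1 / (1 + np.exp(-t))
t_xy, t_wh = np.zeros(2), np.zeros(2)          # zero logits -> sigmoid = 0.5
grid = np.array([4.0, 7.0]) - 0.5              # cell (4, 7) on the stride-8 layer
xy = (sigmoid(t_xy) * 2 + grid) * 8.0          # -> box center (36., 60.)
wh = (sigmoid(t_wh) * 2) ** 2 * anchors[0, 0]  # -> box size (10., 13.)
```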
@@ -203,7 +212,10 @@ def __call__( for f in model_outputs ] - batched_boxes = self.native.eval(model_outputs, conf_thres, iou_thres) + if with_sigmoid: + model_outputs = sigmoid(model_outputs) + + batched_boxes = self.native.eval(model_outputs, conf_thres, iou_thres, None, None) batched_detected_boxes = [] for boxes, preproc_params in zip(batched_boxes, contexts): @@ -213,16 +225,18 @@ def __call__( # rescale boxes for box in boxes: + left, top, right, bottom, score, class_id = box + class_id = int(class_id) detected_boxes.append( ObjectDetectionResult( - index=box.class_id, - label=self.class_names[box.class_id], - score=box.score, + index=class_id, + label=self.class_names[class_id], + score=score, boundingbox=LtrbBoundingBox( - left=(box.left - padw) / scale, - top=(box.top - padh) / scale, - right=(box.right - padw) / scale, - bottom=(box.bottom - padh) / scale, + left=(left - padw) / scale, + top=(top - padh) / scale, + right=(right - padw) / scale, + bottom=(bottom - padh) / scale, ), ) ) @@ -233,7 +247,8 @@ class YOLOv5Base(ObjectDetectionModel, ABC): postprocessor_map: ClassVar[Dict[Platform, Type[PostProcessor]]] = { - Platform.RUST: YOLOv5PostProcessor, + Platform.PYTHON: YOLOv5PythonPostProcessor, + Platform.RUST: YOLOv5NativePostProcessor, } def __init__(self, *args, **kwargs): diff --git a/furiosa/models/vision/yolov5/large.py b/furiosa/models/vision/yolov5/large.py index 25f47ddf..6b29a2c4 100644 --- a/furiosa/models/vision/yolov5/large.py +++ b/furiosa/models/vision/yolov5/large.py @@ -3,6 +3,7 @@ Attributes: CLASSES (List[str]): a list of class names """ + import pathlib from typing import List, Union @@ -24,9 +25,13 @@ class YOLOv5l(YOLOv5Base): """YOLOv5 Large model""" - classes: List[str] = CLASSES - - def __init__(self, *, postprocessor_type: Union[str, Platform] = Platform.RUST): + def __init__( + self, + *, + postprocessor_type: Union[str, Platform] = Platform.RUST, + classes: List[str] = CLASSES, + anchors: np.ndarray = _ANCHORS, + ): postprocessor_type = Platform(postprocessor_type) validate_postprocessor_type(postprocessor_type, self.postprocessor_map.keys()) super().__init__( @@ -35,7 +40,7 @@ def __init__(self, *, postprocessor_type: Union[str, Platform] = Platform.RUST): description="YOLOv5 large model", publication=Publication(url="https://github.com/ultralytics/yolov5"), ), - postprocessor=self.postprocessor_map[postprocessor_type](_ANCHORS, CLASSES), + postprocessor=self.postprocessor_map[postprocessor_type](anchors, classes), ) self._artifact_name = "yolov5l" diff --git a/furiosa/models/vision/yolov5/medium.py b/furiosa/models/vision/yolov5/medium.py index 3e735d55..d5ce23e1 100644 --- a/furiosa/models/vision/yolov5/medium.py +++ b/furiosa/models/vision/yolov5/medium.py @@ -3,6 +3,7 @@ Attributes: CLASSES (List[str]): a list of class names """ + import pathlib from typing import List, Union @@ -24,9 +25,13 @@ class YOLOv5m(YOLOv5Base): """YOLOv5 Medium model""" - classes: List[str] = CLASSES - - def __init__(self, *, postprocessor_type: Union[str, Platform] = Platform.RUST): + def __init__( + self, + *, + postprocessor_type: Union[str, Platform] = Platform.RUST, + classes: List[str] = CLASSES, + anchors: np.ndarray = _ANCHORS, + ): postprocessor_type = Platform(postprocessor_type) validate_postprocessor_type(postprocessor_type, self.postprocessor_map.keys()) super().__init__( @@ -35,7 +40,7 @@ def __init__(self, *, postprocessor_type: Union[str, Platform] = Platform.RUST): description="YOLOv5 medium model",
publication=Publication(url="https://github.com/ultralytics/yolov5"), ), - postprocessor=self.postprocessor_map[postprocessor_type](_ANCHORS, CLASSES), + postprocessor=self.postprocessor_map[postprocessor_type](anchors, classes), ) self._artifact_name = "yolov5m" diff --git a/furiosa/models/vision/yolov5/postprocess.py b/furiosa/models/vision/yolov5/postprocess.py new file mode 100644 index 00000000..75d92cb6 --- /dev/null +++ b/furiosa/models/vision/yolov5/postprocess.py @@ -0,0 +1,228 @@ +from typing import Any, Dict, Sequence + +import numpy as np +import torch +import torchvision + +from ...types import PythonPostProcessor +from ..postprocess import LtrbBoundingBox, ObjectDetectionResult + + +def _reshape_output(feat: np.ndarray, anchor_per_layer_count: int, num_classes: int): + return feat.reshape( + feat.shape[0], # batch + anchor_per_layer_count, + num_classes + 5, # bounding box (4) + objectness score (1) + per-class scores + feat.shape[2], # grid height (H of NCHW) + feat.shape[3], # grid width (W of NCHW) + ).transpose(0, 1, 3, 4, 2) + + +class YOLOv5PythonPostProcessor(PythonPostProcessor): + def __init__(self, anchors, class_names, input_shape=(640, 640)): + self.anchors = anchors + self.class_names = class_names + self.input_shape = input_shape + self.num_layers = anchors.shape[0] + self.anchor_per_layer_count = anchors.shape[1] + self.stride = np.array( + [8.0 * pow(2, i) for i in range(self.num_layers)], + dtype=np.float32, + ) + self.grid, self.anchor_grid = self.init_grid() + + def __call__( + self, + model_outputs: Sequence[np.ndarray], + contexts: Sequence[Dict[str, Any]], + conf_thres: float = 0.25, + iou_thres: float = 0.45, + with_sigmoid: bool = False, + ): + """Convert the outputs of this model to a list of bounding boxes, scores and labels + + Args: + model_outputs: P3/8, P4/16, P5/32 features from the YOLOv5 model. + To learn more about the model output tensors, + please refer to [YOLOv5l Outputs](yolov5l.md#outputs) or + [YOLOv5m Outputs](yolov5m.md#outputs). + contexts: A configuration for each image generated by the preprocessor. + For example, it could be the reduction ratio of the image, the actual image width + and height. + conf_thres: Confidence score threshold. Defaults to 0.25. + iou_thres: IoU threshold value for the NMS processing. Defaults to 0.45. + with_sigmoid: Whether to apply the sigmoid function to the model outputs. Defaults + to False. + + Returns: + Detected Bounding Box and its score and label represented as `ObjectDetectionResult`. + The details of `ObjectDetectionResult` can be found below.
+ + Definition of ObjectDetectionResult and LtrbBoundingBox: + ::: furiosa.models.vision.postprocess.LtrbBoundingBox + options: + show_source: true + ::: furiosa.models.vision.postprocess.ObjectDetectionResult + options: + show_source: true + """ + outputs = [] + model_outputs = [ + _reshape_output(f, self.anchor_per_layer_count, len(self.class_names)) + for f in model_outputs + ] + for model_output, grid, stride, anchor_grid in zip( + model_outputs, self.grid, self.stride, self.anchor_grid + ): + _, _, nx, ny, _ = model_output.shape + if with_sigmoid: + model_output = sigmoid(model_output) + xy, wh, conf = np.split(model_output, [2, 4], axis=4) + xy = (xy * 2 + grid) * stride + wh = (wh * 2) ** 2 * anchor_grid + y = np.concatenate((xy, wh, conf), axis=4) + outputs.append( + np.reshape( + y, + ( + 1, + self.anchor_per_layer_count * nx * ny, + len(self.class_names) + 5, + ), + ) + ) + outputs = np.concatenate(outputs, axis=1) + model_outputs = non_max_suppression(outputs, conf_thres, iou_thres) + + batched_detected_boxes = [] + for boxes, preproc_params in zip(model_outputs, contexts): + scale = preproc_params['scale'] + padw, padh = preproc_params['pad'] + detected_boxes = [] + # rescale boxes + + for box in boxes: + left, top, right, bottom, score, class_id = box + class_id = int(class_id) + detected_boxes.append( + ObjectDetectionResult( + index=class_id, + label=self.class_names[class_id], + score=score, + boundingbox=LtrbBoundingBox( + left=(left - padw) / scale, + top=(top - padh) / scale, + right=(right - padw) / scale, + bottom=(bottom - padh) / scale, + ), + ) + ) + batched_detected_boxes.append(detected_boxes) + + return batched_detected_boxes + + def init_grid(self): + grid = [np.zeros(1)] * self.num_layers + anchor_grid = [np.zeros(1)] * self.num_layers + + nx_ny = [ + ( + int(self.input_shape[0] / (8 * pow(2, i))), + int(self.input_shape[1] / (8 * pow(2, i))), + ) + for i in range(self.num_layers) + ] + for i in range(self.num_layers): + grid[i], anchor_grid[i] = self.make_grid(nx_ny[i][0], nx_ny[i][1], i) + + return grid, anchor_grid + + def make_grid(self, nx: int, ny: int, i: int): + shape = 1, self.anchor_per_layer_count, ny, nx, 2 + y, x = np.arange(ny, dtype=np.float32), np.arange(nx, dtype=np.float32) + yv, xv = np.meshgrid(y, x, indexing="ij") + grid = np.broadcast_to(np.stack((xv, yv), axis=2), shape) - 0.5 + anchor_grid = np.broadcast_to( + np.reshape( + self.anchors[i] * self.stride[i], + (1, self.anchor_per_layer_count, 1, 1, 2), + ), + shape, + ) + return grid, anchor_grid + + +def sigmoid(x: np.ndarray) -> np.ndarray: + # pylint: disable=invalid-name + return 1 / (1 + np.exp(-x)) + + +# https://github.com/ultralytics/yolov5/blob/v7.0/utils/general.py#L884-L999 +def non_max_suppression( + prediction: np.ndarray, + conf_thres: float, + iou_thres: float, + agnostic: bool = False, +): + # pylint: disable=invalid-name,too-many-locals + + batch_size = prediction.shape[0] + candidates = prediction[..., 4] > conf_thres + assert 0 <= conf_thres <= 1, conf_thres + assert 0 <= iou_thres <= 1, iou_thres + + max_wh = 7680 # (pixels) maximum box width and height + max_nms = 30000 # maximum number of boxes fed into NMS + max_det: int = 300 + + output = [np.empty((0, 6))] * batch_size + for xi, x in enumerate(prediction): + x = x[candidates[xi]] + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box/Mask + box = xywh2xyxy(x[:, :4]) # (center_x, center_y, width, height) to (x1, y1, x2, y2) + + i, j = np.where(x[:, 5:] > conf_thres) + x = np.concatenate( + ( + 
box[i], + x[i, j + 5, np.newaxis].astype(np.float32), + j[:, np.newaxis].astype(np.float32), + ), + axis=1, + ) + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + x = x[np.argsort(x[:, 4])[::-1][:max_nms]] # sort by confidence and remove excess boxes + + # NMS + classes = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + classes, x[:, 4] # boxes (offset by class), scores + + i = torchvision.ops.nms( + torch.from_numpy(boxes), torch.from_numpy(scores), iou_thres + ).numpy() + i = i[:max_det] # limit detections + + output[xi] = x[i] + + return output + + +# https://github.com/ultralytics/yolov5/blob/v7.0/utils/general.py#L760-L767 +def xywh2xyxy(x: np.ndarray) -> np.ndarray: + # pylint: disable=invalid-name + y = np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 + y[:, 1] = x[:, 1] - x[:, 3] / 2 + y[:, 2] = x[:, 0] + x[:, 2] / 2 + y[:, 3] = x[:, 1] + x[:, 3] / 2 + return y diff --git a/pyproject.toml b/pyproject.toml index fc42bc80..a94c3ccd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "furiosa-common == 0.10.*", "furiosa-device == 0.10.*", "furiosa-runtime == 0.10.*", - "furiosa-native-postprocess == 0.9.0", + "furiosa-native-postprocess == 0.10.0", "PyYAML", "numpy", @@ -109,7 +109,7 @@ force_sort_within_sections = true known_first_party = ["furiosa"] line_length = 100 profile = "black" -extend_skip_glob = ["**/generated/**", "tests/assets/**", ".dvc/**"] +extend_skip_glob = ["**/generated/**", "tests/data/**", ".dvc/**"] [tool.ruff] line-length = 100 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/accuracy/__init__.py b/tests/accuracy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/accuracy/test_yolov5l.py b/tests/accuracy/test_yolov5l.py new file mode 100644 index 00000000..d371b6f0 --- /dev/null +++ b/tests/accuracy/test_yolov5l.py @@ -0,0 +1,133 @@ +import asyncio +import itertools +import os +from pathlib import Path +from typing import Tuple + +import tqdm + +from furiosa.models.vision import YOLOv5l +from furiosa.runtime import create_runner + +from ..bench.test_acc_util import bdd100k + +EXPECTED_MAP = 0.2894519996558422 +EXPECTED_MAP_RUST = 0.2895171558198357 + +CONF_THRES = 0.001 +IOU_THRES = 0.45 + + +def load_db_from_env_variable() -> Tuple[Path, bdd100k.Yolov5Dataset]: + MUST_10K_LIMIT = 10_000 + dataset_path = Path(os.environ.get('YOLOV5_DATASET_PATH', "./tests/data/bdd100k_val")) + + db = bdd100k.Yolov5Dataset(dataset_path, mode="val", limit=MUST_10K_LIMIT) + + return dataset_path, db + + +async def test_yolov5l_accuracy(): + model: YOLOv5l = YOLOv5l(postprocessor_type="Python") + + image_directory, yolov5db = load_db_from_env_variable() + + print(f"dataset_path: {image_directory}") + metric = bdd100k.MAPMetricYolov5(num_classes=10) + + num_images = len(yolov5db) + yolov5db = iter(yolov5db) + + async def workload(im): + batch_im = [im] + + batch_pre_img, batch_preproc_param = model.preprocess(batch_im) + batch_feat = await runner.run(batch_pre_img) + detected_boxes = model.postprocess( + batch_feat, batch_preproc_param, conf_thres=CONF_THRES, iou_thres=IOU_THRES + ) + + return bdd100k.to_numpy(detected_boxes[0]) + + async with create_runner(model.model_source(num_pe=1), device="warboy(1)*1") as runner: + steps = 10 + assert num_images % steps == 0, "cannot divide by step" + iters = num_images // steps + for _ in tqdm.tqdm(range(steps), desc="yolov5l accuracy w/ python pp"):
+ worklist = [] + bxtargets = [] + clstargets = [] + for im, boxes_target, classes_target in itertools.islice(yolov5db, iters): + bxtargets.append(boxes_target) + clstargets.append(classes_target) + worklist.append(workload(im)) + for det_out, boxes_target, classes_target in zip( + await asyncio.gather(*worklist), bxtargets, clstargets + ): + metric( + boxes_pred=det_out[:, :4], + scores_pred=det_out[:, 4], + classes_pred=det_out[:, 5], + boxes_target=boxes_target, + classes_target=classes_target, + ) + + result = metric.compute() + print("YOLOv5Large mAP:", result['map']) + print("YOLOv5Large mAP50:", result['map50']) + print("YOLOv5Large ap_class:", result['ap_class']) + print("YOLOv5Large ap50_class:", result['ap50_class']) + assert result['map'] == EXPECTED_MAP, "Accuracy check w/ python failed" + + +async def test_yolov5l_with_native_rust_pp_accuracy(): + model: YOLOv5l = YOLOv5l(postprocessor_type="Rust") + + image_directory, yolov5db = load_db_from_env_variable() + + print(f"dataset_path: {image_directory}") + metric = bdd100k.MAPMetricYolov5(num_classes=10) + + num_images = len(yolov5db) + yolov5db = iter(yolov5db) + + async def workload(im): + batch_im = [im] + + batch_pre_img, batch_preproc_param = model.preprocessor.__call__(batch_im) + batch_feat = await runner.run(batch_pre_img) + detected_boxes = model.postprocessor.__call__( + batch_feat, batch_preproc_param, conf_thres=CONF_THRES, iou_thres=IOU_THRES + ) + + return bdd100k.to_numpy(detected_boxes[0]) + + async with create_runner(model.model_source(num_pe=1), device="warboy(1)*1") as runner: + steps = 10 + assert num_images % steps == 0, "cannot divide by step" + iters = num_images // steps + for _ in tqdm.tqdm(range(steps), desc="yolov5l accuracy w/ rust pp"): + worklist = [] + bxtargets = [] + clstargets = [] + for im, boxes_target, classes_target in itertools.islice(yolov5db, iters): + bxtargets.append(boxes_target) + clstargets.append(classes_target) + worklist.append(workload(im)) + for det_out, boxes_target, classes_target in zip( + await asyncio.gather(*worklist), bxtargets, clstargets + ): + metric( + boxes_pred=det_out[:, :4], + scores_pred=det_out[:, 4], + classes_pred=det_out[:, 5], + boxes_target=boxes_target, + classes_target=classes_target, + ) + + result = metric.compute() + print("YOLOv5Large mAP:", result['map']) + print("YOLOv5Large mAP50:", result['map50']) + print("YOLOv5Large ap_class:", result['ap_class']) + print("YOLOv5Large ap50_class:", result['ap50_class']) + assert result['map'] == EXPECTED_MAP_RUST, "Accuracy check w/ rust failed" diff --git a/tests/accuracy/test_yolov5m.py b/tests/accuracy/test_yolov5m.py new file mode 100644 index 00000000..e3aa00c8 --- /dev/null +++ b/tests/accuracy/test_yolov5m.py @@ -0,0 +1,135 @@ +import asyncio +import itertools +import os +from pathlib import Path +from typing import Tuple + +import tqdm + +from furiosa.models.vision import YOLOv5m +from furiosa.runtime import create_runner + +from ..bench.test_acc_util import bdd100k + +EXPECTED_MAP = 0.27702783413351617 +EXPECTED_MAP_RUST = 0.2769884679629229 + +CONF_THRES = 0.001 +IOU_THRES = 0.45 + + +def load_db_from_env_variable() -> Tuple[Path, bdd100k.Yolov5Dataset]: + MUST_10K_LIMIT = 10_000 + dataset_path = Path(os.environ.get('YOLOV5_DATASET_PATH', "./tests/data/bdd100k_val")) + + db = bdd100k.Yolov5Dataset(dataset_path, mode="val", limit=MUST_10K_LIMIT) + + return dataset_path, db + + +async def test_yolov5m_accuracy(): + model: YOLOv5m = YOLOv5m(postprocessor_type="Python") + + image_directory, yolov5db 
= load_db_from_env_variable() + + print(f"dataset_path: {image_directory}") + metric = bdd100k.MAPMetricYolov5(num_classes=10) + + num_images = len(yolov5db) + yolov5db = iter(yolov5db) + + async def workload(im): + batch_im = [im] + + batch_pre_img, batch_preproc_param = model.preprocess(batch_im) + batch_feat = await runner.run(batch_pre_img) + detected_boxes = model.postprocess( + batch_feat, batch_preproc_param, conf_thres=CONF_THRES, iou_thres=IOU_THRES + ) + + return bdd100k.to_numpy(detected_boxes[0]) + + async with create_runner(model.model_source(num_pe=1), device="warboy(1)*1") as runner: + steps = 10 + assert num_images % steps == 0, "cannot divide by step" + iters = num_images // steps + for _ in tqdm.tqdm(range(steps), desc="yolov5m accuracy w/ python pp"): + worklist = [] + bxtargets = [] + clstargets = [] + for im, boxes_target, classes_target in itertools.islice(yolov5db, iters): + bxtargets.append(boxes_target) + clstargets.append(classes_target) + worklist.append(workload(im)) + for det_out, boxes_target, classes_target in zip( + await asyncio.gather(*worklist), bxtargets, clstargets + ): + metric( + boxes_pred=det_out[:, :4], + scores_pred=det_out[:, 4], + classes_pred=det_out[:, 5], + boxes_target=boxes_target, + classes_target=classes_target, + ) + + result = metric.compute() + print("YOLOv5Medium mAP:", result['map']) + print("YOLOv5Medium mAP50:", result['map50']) + print("YOLOv5Medium ap_class:", result['ap_class']) + print("YOLOv5Medium ap50_class:", result['ap50_class']) + + assert result['map'] == EXPECTED_MAP, "Accuracy check w/ python failed" + + +async def test_yolov5m_with_native_rust_pp_accuracy(): + model: YOLOv5m = YOLOv5m(postprocessor_type="Rust") + + image_directory, yolov5db = load_db_from_env_variable() + + print(f"dataset_path: {image_directory}") + metric = bdd100k.MAPMetricYolov5(num_classes=10) + + num_images = len(yolov5db) + yolov5db = iter(yolov5db) + + async def workload(im): + batch_im = [im] + + batch_pre_img, batch_preproc_param = model.preprocessor.__call__(batch_im) + batch_feat = await runner.run(batch_pre_img) + detected_boxes = model.postprocessor.__call__( + batch_feat, batch_preproc_param, conf_thres=CONF_THRES, iou_thres=IOU_THRES + ) + + return bdd100k.to_numpy(detected_boxes[0]) + + async with create_runner(model.model_source(num_pe=1), device="warboy(1)*1") as runner: + steps = 10 + assert num_images % steps == 0, "cannot divide by step" + iters = num_images // steps + for _ in tqdm.tqdm(range(steps), desc="yolov5m accuracy w/ rust pp"): + worklist = [] + bxtargets = [] + clstargets = [] + for im, boxes_target, classes_target in itertools.islice(yolov5db, iters): + bxtargets.append(boxes_target) + clstargets.append(classes_target) + worklist.append(workload(im)) + for det_out, boxes_target, classes_target in zip( + await asyncio.gather(*worklist), bxtargets, clstargets + ): + metric( + boxes_pred=det_out[:, :4], + scores_pred=det_out[:, 4], + classes_pred=det_out[:, 5], + boxes_target=boxes_target, + classes_target=classes_target, + ) + + result = metric.compute() + print("YOLOv5Medium mAP:", result['map']) + print("YOLOv5Medium mAP50:", result['map50']) + print("YOLOv5Medium ap_class:", result['ap_class']) + print("YOLOv5Medium ap50_class:", result['ap50_class']) + + assert result['map'] == EXPECTED_MAP_RUST, "Accuracy check w/ rust failed" diff --git a/tests/bench/test_ssd_mobilenet.py b/tests/bench/test_ssd_mobilenet.py index 4bab6d66..6219d50b 100644 --- a/tests/bench/test_ssd_mobilenet.py +++ 
b/tests/bench/test_ssd_mobilenet.py @@ -12,7 +12,7 @@ from furiosa.runtime.sync import create_runner EXPECTED_ACCURACY = 0.2319698092633901 -EXPECTED_ACCURACY_NATIVE_RUST_PP = 0.23178397430922199 +EXPECTED_ACCURACY_NATIVE_RUST_PP = 0.23178929362182082 def load_coco_from_env_variable(): diff --git a/tests/bench/test_yolov5l.py b/tests/bench/test_yolov5l.py index 73ef9736..b1344ac4 100644 --- a/tests/bench/test_yolov5l.py +++ b/tests/bench/test_yolov5l.py @@ -10,7 +10,11 @@ from .test_acc_util import bdd100k -EXPECTED_MAP = 0.28385080080789205 +EXPECTED_MAP = 0.2894519996558422 +EXPECTED_MAP_RUST = 0.2895171558198357 + +CONF_THRES = 0.001 +IOU_THRES = 0.45 def load_db_from_env_variable() -> Tuple[Path, bdd100k.Yolov5Dataset]: @@ -23,12 +27,12 @@ def load_db_from_env_variable() -> Tuple[Path, bdd100k.Yolov5Dataset]: def test_yolov5l_accuracy(benchmark): - model: YOLOv5l = YOLOv5l() + model: YOLOv5l = YOLOv5l(postprocessor_type="Python") image_directory, yolov5db = load_db_from_env_variable() print(f"dataset_path: {image_directory}") - metric = bdd100k.MAPMetricYolov5(num_classes=len(model.classes)) + metric = bdd100k.MAPMetricYolov5(num_classes=10) num_images = len(yolov5db) yolov5db = iter(tqdm(yolov5db)) @@ -45,7 +49,7 @@ def workload(im, boxes_target, classes_target): ) # single-batch batch_feat = runner.run(np.expand_dims(batch_pre_img[0], axis=0)) detected_boxes = model.postprocess( - batch_feat, batch_preproc_param, conf_thres=0.001, iou_thres=0.6 + batch_feat, batch_preproc_param, conf_thres=CONF_THRES, iou_thres=IOU_THRES ) det_out = bdd100k.to_numpy(detected_boxes[0]) metric( @@ -65,3 +69,48 @@ def workload(im, boxes_target, classes_target): print("YOLOv5Large ap_class:", result['ap_class']) print("YOLOv5Large ap50_class:", result['ap50_class']) assert result['map'] == EXPECTED_MAP, "Accuracy check failed" + + +def test_yolov5l_accuracy_with_rust_pp(benchmark): + model: YOLOv5l = YOLOv5l(postprocessor_type="Rust") + + image_directory, yolov5db = load_db_from_env_variable() + + print(f"dataset_path: {image_directory}") + metric = bdd100k.MAPMetricYolov5(num_classes=10) + + num_images = len(yolov5db) + yolov5db = iter(tqdm(yolov5db)) + + def read_image(): + im, boxes_target, classes_target = next(yolov5db) + return (im, boxes_target, classes_target), {} + + def workload(im, boxes_target, classes_target): + batch_im = [im] + + batch_pre_img, batch_preproc_param = model.preprocess( + batch_im, + ) # single-batch + batch_feat = runner.run(np.expand_dims(batch_pre_img[0], axis=0)) + detected_boxes = model.postprocess( + batch_feat, batch_preproc_param, conf_thres=CONF_THRES, iou_thres=IOU_THRES + ) + det_out = bdd100k.to_numpy(detected_boxes[0]) + metric( + boxes_pred=det_out[:, :4], + scores_pred=det_out[:, 4], + classes_pred=det_out[:, 5], + boxes_target=boxes_target, + classes_target=classes_target, + ) + + with create_runner(model.model_source()) as runner: + benchmark.pedantic(workload, setup=read_image, rounds=num_images) + + result = metric.compute() + print("YOLOv5Large mAP:", result['map']) + print("YOLOv5Large mAP50:", result['map50']) + print("YOLOv5Large ap_class:", result['ap_class']) + print("YOLOv5Large ap50_class:", result['ap50_class']) + assert result['map'] == EXPECTED_MAP_RUST, "Accuracy check failed" diff --git a/tests/bench/test_yolov5m.py b/tests/bench/test_yolov5m.py index 4f763fcc..2f8e13b7 100644 --- a/tests/bench/test_yolov5m.py +++ b/tests/bench/test_yolov5m.py @@ -10,7 +10,11 @@ from .test_acc_util import bdd100k -EXPECTED_MAP = 0.2716221365849332 +EXPECTED_MAP 
= 0.27702783413351617 +EXPECTED_MAP_RUST = 0.2769884679629229 + +CONF_THRES = 0.001 +IOU_THRES = 0.45 def load_db_from_env_variable() -> Tuple[Path, bdd100k.Yolov5Dataset]: @@ -23,12 +27,12 @@ def load_db_from_env_variable() -> Tuple[Path, bdd100k.Yolov5Dataset]: def test_yolov5m_accuracy(benchmark): - model: YOLOv5m = YOLOv5m() + model: YOLOv5m = YOLOv5m(postprocessor_type="Python") image_directory, yolov5db = load_db_from_env_variable() print(f"dataset_path: {image_directory}") - metric = bdd100k.MAPMetricYolov5(num_classes=len(model.classes)) + metric = bdd100k.MAPMetricYolov5(num_classes=10) num_images = len(yolov5db) yolov5db = iter(tqdm(yolov5db)) @@ -45,7 +49,7 @@ def workload(im, boxes_target, classes_target): ) # single-batch batch_feat = runner.run(np.expand_dims(batch_pre_img[0], axis=0)) detected_boxes = model.postprocess( - batch_feat, batch_preproc_param, conf_thres=0.001, iou_thres=0.6 + batch_feat, batch_preproc_param, conf_thres=CONF_THRES, iou_thres=IOU_THRES ) det_out = bdd100k.to_numpy(detected_boxes[0]) @@ -68,3 +72,51 @@ def workload(im, boxes_target, classes_target): print("YOLOv5Medium ap50_class:", result['ap50_class']) assert result['map'] == EXPECTED_MAP, "Accuracy check failed" + + +def test_yolov5m_with_native_rust_pp_accuracy(benchmark): + model: YOLOv5m = YOLOv5m(postprocessor_type="Rust") + + image_directory, yolov5db = load_db_from_env_variable() + + print(f"dataset_path: {image_directory}") + metric = bdd100k.MAPMetricYolov5(num_classes=10) + + num_images = len(yolov5db) + yolov5db = iter(tqdm(yolov5db)) + + def read_image(): + im, boxes_target, classes_target = next(yolov5db) + return (im, boxes_target, classes_target), {} + + def workload(im, boxes_target, classes_target): + batch_im = [im] + + batch_pre_img, batch_preproc_param = model.preprocess( + batch_im, + ) # single-batch + batch_feat = runner.run(np.expand_dims(batch_pre_img[0], axis=0)) + detected_boxes = model.postprocess( + batch_feat, batch_preproc_param, conf_thres=CONF_THRES, iou_thres=IOU_THRES + ) + + det_out = bdd100k.to_numpy(detected_boxes[0]) + + metric( + boxes_pred=det_out[:, :4], + scores_pred=det_out[:, 4], + classes_pred=det_out[:, 5], + boxes_target=boxes_target, + classes_target=classes_target, + ) + + with create_runner(model.model_source()) as runner: + benchmark.pedantic(workload, setup=read_image, rounds=num_images) + + result = metric.compute() + print("YOLOv5Medium mAP:", result['map']) + print("YOLOv5Medium mAP50:", result['map50']) + print("YOLOv5Medium ap_class:", result['ap_class']) + print("YOLOv5Medium ap50_class:", result['ap50_class']) + + assert result['map'] == EXPECTED_MAP_RUST, "Accuracy check failed" diff --git a/tests/unit/test_batched_yolov5l.py b/tests/unit/test_batched_yolov5l.py index 5cbb6656..f839fc04 100644 --- a/tests/unit/test_batched_yolov5l.py +++ b/tests/unit/test_batched_yolov5l.py @@ -9,14 +9,12 @@ TEST_IMAGE_PATH = str(Path(__file__).parent / "../assets/yolov5-test.jpg") -NUM_CLASSES = 10 NUM_BATCHES = 2 NUM_DETECTED_BOXES = 26 def test_yolov5_large_batched(): m = YOLOv5l() - assert len(m.classes) == NUM_CLASSES, "expected CLASS is 10" batch_im = [cv2.imread(TEST_IMAGE_PATH), cv2.imread(TEST_IMAGE_PATH)] with create_runner(m.model_source()) as runner: diff --git a/tests/unit/test_batched_yolov5m.py b/tests/unit/test_batched_yolov5m.py index 363ee766..e6774604 100644 --- a/tests/unit/test_batched_yolov5m.py +++ b/tests/unit/test_batched_yolov5m.py @@ -9,14 +9,12 @@ TEST_IMAGE_PATH = str(Path(__file__).parent / "../assets/yolov5-test.jpg") 
-NUM_CLASSES = 10 NUM_BATCHES = 2 NUM_DETECTED_BOXES = 21 def test_yolov5_medium_batched(): m = YOLOv5m() - assert len(m.classes) == NUM_CLASSES, "expected CLASS is 10" batch_im = [cv2.imread(TEST_IMAGE_PATH), cv2.imread(TEST_IMAGE_PATH)] with create_runner(m.model_source()) as runner:
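A closing note on the class-aware NMS this patch introduces: `non_max_suppression` runs `torchvision.ops.nms` once over all classes by offsetting each box by `class_id * max_wh`, so boxes of different classes can never overlap and therefore never suppress each other. A standalone sketch of the trick with synthetic boxes:

```python
import numpy as np
import torch
import torchvision

# Two perfectly overlapping boxes with different classes: class-aware NMS
# (the offset trick used in non_max_suppression) keeps both, while a
# class-agnostic pass would drop the lower-scoring one.
boxes = np.array([[100.0, 100.0, 200.0, 200.0],
                  [100.0, 100.0, 200.0, 200.0]], dtype=np.float32)
scores = np.array([0.9, 0.8], dtype=np.float32)
class_ids = np.array([[0.0], [1.0]], dtype=np.float32)

max_wh = 7680  # larger than any image side, so per-class offsets never collide
offset_boxes = boxes + class_ids * max_wh

keep = torchvision.ops.nms(
    torch.from_numpy(offset_boxes), torch.from_numpy(scores), 0.45
)
print(keep.numpy())  # -> [0 1]: both boxes survive because their classes differ
```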