[AIR] Add TorchDetectionPredictor #32199

Merged
merged 25 commits into from Feb 8, 2023
Commits (25)
3203f0c  Initial commit (bveeramani, Jan 26, 2023)
28fa37c  Address review comments (bveeramani, Jan 30, 2023)
2f5fd21  Update utils.py (bveeramani, Jan 31, 2023)
679f3b1  Merge remote-tracking branch 'upstream/master' into promote-create-ra… (bveeramani, Jan 31, 2023)
e3a7630  Merge remote-tracking branch 'upstream/master' into promote-create-ra… (bveeramani, Jan 31, 2023)
10d78cf  Address review comments (bveeramani, Jan 31, 2023)
0d331c5  Update tensor_extension.py (bveeramani, Jan 31, 2023)
328b3f8  Update tensor_extension.py (bveeramani, Jan 31, 2023)
694764e  Merge remote-tracking branch 'upstream/master' into promote-create-ra… (bveeramani, Feb 2, 2023)
9855e94  Initial commit (bveeramani, Feb 2, 2023)
bbbda1f  Merge branch 'dtype-optional' into detection-predictor (bveeramani, Feb 2, 2023)
22b3aa9  Add `TorchDetectionPredictor` (bveeramani, Feb 3, 2023)
ab67b37  Fix test (bveeramani, Feb 6, 2023)
750cc0b  Address review comments (bveeramani, Feb 6, 2023)
edf14b8  Merge branch 'dtype-optional' into detection-predictor (bveeramani, Feb 6, 2023)
43693e5  Merge branch 'promote-create-ragged' into detection-predictor (bveeramani, Feb 6, 2023)
dd3a9eb  Address review comments (bveeramani, Feb 7, 2023)
6ecf830  Update torch_detection_predictor.py (bveeramani, Feb 7, 2023)
6374151  Fix Bazel (bveeramani, Feb 8, 2023)
c85f5d1  Merge remote-tracking branch 'upstream/master' into detection-predictor (bveeramani, Feb 8, 2023)
2f41527  Merge remote-tracking branch 'upstream/master' into detection-predictor (bveeramani, Feb 8, 2023)
bdf28f9  Update BUILD (bveeramani, Feb 8, 2023)
163d2af  Update torch_predictor.py (bveeramani, Feb 8, 2023)
a55bc0a  Merge remote-tracking branch 'upstream/master' into detection-predictor (bveeramani, Feb 8, 2023)
cea988f  Update torch_predictor.py (bveeramani, Feb 8, 2023)
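
Before the file-by-file diff, here is a minimal usage sketch of the new API. It mirrors the class docstring added in this PR; the pretrained Faster R-CNN model and the zero-filled input batch are illustrative only.

import numpy as np
from torchvision import models

from ray.train.torch import TorchDetectionPredictor

# Any TorchVision detection model can be wrapped; Faster R-CNN is only an example.
model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True)
predictor = TorchDetectionPredictor(model=model)

# The result is a dict of per-image ndarrays keyed by "pred_boxes",
# "pred_labels", and "pred_scores".
predictions = predictor.predict(np.zeros((4, 3, 32, 32), dtype=np.float32))
print(predictions.keys())
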
14 changes: 12 additions & 2 deletions doc/source/train/api.rst
@@ -89,11 +89,21 @@ PyTorch
``TorchPredictor``
******************

.. automodule:: ray.train.torch
.. autoclass:: ray.train.torch.TorchPredictor
:members:
:exclude-members: TorchTrainer
:show-inheritance:

.. automethod:: __init__

``TorchDetectionPredictor``
***************************

.. autoclass:: ray.train.torch.TorchDetectionPredictor
:members:
:show-inheritance:

.. automethod:: __init__

Horovod
~~~~~~~

8 changes: 8 additions & 0 deletions python/ray/train/BUILD
@@ -496,6 +496,14 @@ py_test(
deps = [":train_lib", ":conftest"]
)

py_test(
name = "test_torch_detection_predictor",
size = "small",
srcs = ["tests/test_torch_detection_predictor.py"],
tags = ["team:ml", "exclusive", "ray_air", "gpu"],
deps = [":train_lib", ":conftest"]
)

py_test(
name = "test_torch_trainer",
size = "medium",
118 changes: 118 additions & 0 deletions python/ray/train/tests/test_torch_detection_predictor.py
@@ -0,0 +1,118 @@
import numpy as np
import pytest
from torchvision import models

import ray
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchDetectionPredictor


@pytest.fixture(name="predictor")
def predictor_fixture():
model = models.detection.maskrcnn_resnet50_fpn()
yield TorchDetectionPredictor(model=model)


@pytest.mark.parametrize(
"data",
[
np.zeros((1, 3, 32, 32), dtype=np.float32),
{"image": np.zeros((1, 3, 32, 32), dtype=np.float32)},
create_ragged_ndarray(
[
np.zeros((3, 32, 32), dtype=np.float32),
np.zeros((3, 64, 64), dtype=np.float32),
]
),
],
)
def test_predict(predictor, data):
predictions = predictor.predict(data)

assert all(len(value) == len(data) for value in predictions.values())
# Boxes should have shape `(# detections, 4)`.
assert all(boxes.ndim == 2 for boxes in predictions["pred_boxes"])
assert all(boxes.shape[-1] == 4 for boxes in predictions["pred_boxes"])
# Labels should have shape `(# detections,)`.
assert all(labels.ndim == 1 for labels in predictions["pred_labels"])
# Scores should have shape `(# detections,)`.
assert all(scores.ndim == 1 for scores in predictions["pred_scores"])


def test_predict_tensor_dataset():
model = models.detection.maskrcnn_resnet50_fpn()
checkpoint = TorchCheckpoint.from_model(model)
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)
dataset = ray.data.from_items([np.zeros((3, 32, 32), dtype=np.float32)])

predictions = predictor.predict(dataset)

# Boxes should have shape `(# detections, 4)`.
pred_boxes = [row["pred_boxes"] for row in predictions.take_all()]
assert all(boxes.ndim == 2 for boxes in pred_boxes)
assert all(boxes.shape[-1] == 4 for boxes in pred_boxes)
# Labels should have shape `(# detections,)`.
pred_labels = [row["pred_labels"] for row in predictions.take_all()]
assert all(labels.ndim == 1 for labels in pred_labels)
# Scores should have shape `(# detections,)`.
pred_scores = [row["pred_scores"] for row in predictions.take_all()]
assert all(scores.ndim == 1 for scores in pred_scores)


@pytest.mark.parametrize(
"items",
[
[{"image": np.zeros((3, 32, 32), dtype=np.float32)}],
[
{"image": np.zeros((3, 32, 32), dtype=np.float32)},
{"image": np.zeros((3, 64, 64), dtype=np.float32)},
],
],
)
def test_predict_tabular_dataset(items):
model = models.detection.maskrcnn_resnet50_fpn()
checkpoint = TorchCheckpoint.from_model(model)
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)
dataset = ray.data.from_items(items)

predictions = predictor.predict(dataset)

assert predictions.count() == len(items)
# Boxes should have shape `(# detections, 4)`.
pred_boxes = [row["pred_boxes"] for row in predictions.take_all()]
assert all(boxes.ndim == 2 for boxes in pred_boxes)
assert all(boxes.shape[-1] == 4 for boxes in pred_boxes)
# Labels should have shape `(# detections,)`.
pred_labels = [row["pred_labels"] for row in predictions.take_all()]
assert all(labels.ndim == 1 for labels in pred_labels)
# Scores should have shape `(# detections,)`.
pred_scores = [row["pred_scores"] for row in predictions.take_all()]
assert all(scores.ndim == 1 for scores in pred_scores)


def test_multi_column_batch_raises_value_error(predictor):
data = {
"image": np.zeros((2, 3, 32, 32), dtype=np.float32),
"boxes": np.zeros((2, 0, 4), dtype=np.float32),
"labels": np.zeros((2, 0), dtype=np.int64),
}
with pytest.raises(ValueError):
# `data` should only contain one key. Otherwise, `TorchDetectionPredictor`
# doesn't know which column contains the input images.
predictor.predict(data)


def test_invalid_dtype_raises_value_error(predictor):
data = np.zeros((1, 3, 32, 32), dtype=np.float32)
with pytest.raises(ValueError):
# `dtype` should be a single `torch.dtype`.
predictor.predict(data, dtype=np.float32)


if __name__ == "__main__":
import sys

import pytest

sys.exit(pytest.main(["-v", "-x", __file__]))
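
The tests above build variable-sized image batches with `create_ragged_ndarray`. Below is a minimal sketch of what that utility returns, assuming it behaves as the tests and the predictor docstring imply (a one-dimensional object-dtype array whose elements are the individual images).

import numpy as np

from ray.air.util.tensor_extensions.utils import create_ragged_ndarray

# Two images with different spatial sizes cannot be stacked into one
# rectangular ndarray, so they are wrapped in a ragged (object-dtype) array.
images = create_ragged_ndarray(
    [
        np.zeros((3, 32, 32), dtype=np.float32),
        np.zeros((3, 64, 64), dtype=np.float32),
    ]
)
print(images.shape)  # (2,)
print(images.dtype)  # object
print(images[0].shape, images[1].shape)  # (3, 32, 32) (3, 64, 64)
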
4 changes: 3 additions & 1 deletion python/ray/train/torch/__init__.py
@@ -7,8 +7,9 @@
)
# isort: on

from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.train.torch.config import TorchConfig
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.train.torch.torch_detection_predictor import TorchDetectionPredictor
from ray.train.torch.torch_predictor import TorchPredictor
from ray.train.torch.torch_trainer import TorchTrainer
from ray.train.torch.train_loop_utils import (
@@ -33,4 +34,5 @@
"backward",
"enable_reproducibility",
"TorchPredictor",
"TorchDetectionPredictor",
]
140 changes: 140 additions & 0 deletions python/ray/train/torch/torch_detection_predictor.py
@@ -0,0 +1,140 @@
import collections
from typing import Dict, List, Optional, Union

import numpy as np
import torch

from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
from ray.train._internal.dl_predictor import TensorDtype
from ray.train.torch.torch_predictor import TorchPredictor
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class TorchDetectionPredictor(TorchPredictor):
"""A predictor for TorchVision detection models.

Unlike other Torch models, instance segmentation models return
`List[Dict[str, Tensor]]`. This predictor extends :class:`TorchPredictor` to support
the non-standard outputs.

To learn more about instance segmentation models, read
`Instance segmentation models <https://pytorch.org/vision/main/auto_examples/plot_visualization_utils.html#instance-seg-output>`_.

Example:

.. testcode::

import numpy as np
from torchvision import models

from ray.train.torch import TorchDetectionPredictor

model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True)

predictor = TorchDetectionPredictor(model=model)
predictions = predictor.predict(np.zeros((4, 3, 32, 32), dtype=np.float32))

print(predictions.keys())

.. testoutput::

dict_keys(['pred_boxes', 'pred_labels', 'pred_scores'])

.. testcode::

import numpy as np
from torchvision import models

import ray
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchDetectionPredictor

dataset = ray.data.from_items([{"image": np.zeros((3, 32, 32), dtype=np.float32)}])
model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True)
checkpoint = TorchCheckpoint.from_model(model)
predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)
predictions = predictor.predict(dataset, feature_columns=["image"])

print(predictions.take(1))

.. testoutput::

[{'pred_boxes': array([], shape=(0, 4), dtype=float32), 'pred_labels': array([], dtype=int64), 'pred_scores': array([], dtype=float32)}]
""" # noqa: E501

def _predict_numpy(
self,
data: Union[np.ndarray, Dict[str, np.ndarray]],
dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
) -> Dict[str, np.ndarray]:
if isinstance(data, dict) and len(data) != 1:
raise ValueError(
f"""Expected input to contain one key, but got {len(data)} instead.

If you're using `BatchPredictor`, pass a one-element list to
`feature_columns`.

---
predictor = BatchPredictor(checkpoint, TorchDetectionPredictor)
predictor.predict(dataset, feature_columns=["image"])
---
"""
)

if dtype is not None and not isinstance(dtype, torch.dtype):
raise ValueError(
"Expected `dtype` to be a `torch.dtype`, but got a "
f"{type(dtype).__name__} instead."
)

if isinstance(data, dict):
images = next(iter(data.values()))
else:
images = data

inputs = [
torch.as_tensor(image, dtype=dtype).to(self.device) for image in images
]
outputs = self.call_model(inputs)
outputs = _convert_outputs_to_ndarray_batch(outputs)
outputs = {"pred_" + key: value for key, value in outputs.items()}

return outputs


def _convert_outputs_to_ndarray_batch(
outputs: List[Dict[str, torch.Tensor]],
) -> Dict[str, np.ndarray]:
"""Batch detection model outputs.

TorchVision detection models return `List[Dict[str, Tensor]]`. Each `Dict` contains
'boxes', 'labels', and 'scores'.

>>> import torch
>>> from torchvision import models
>>> model = models.detection.fasterrcnn_resnet50_fpn_v2()
>>> model.eval() # doctest: +ELLIPSIS
FasterRCNN(...)
>>> outputs = model(torch.zeros((2, 3, 32, 32)))
>>> len(outputs)
2
>>> outputs[0].keys()
dict_keys(['boxes', 'labels', 'scores'])

This function batches values and returns a `Dict[str, np.ndarray]`.

>>> from ray.train.torch.torch_detection_predictor import _convert_outputs_to_ndarray_batch
>>> batch = _convert_outputs_to_ndarray_batch(outputs)
>>> batch.keys()
dict_keys(['boxes', 'labels', 'scores'])
>>> batch["boxes"].shape
(2,)
""" # noqa: E501
batch = collections.defaultdict(list)
for output in outputs:
for key, value in output.items():
batch[key].append(value.cpu().detach().numpy())
for key, value in batch.items():
batch[key] = create_ragged_ndarray(value)
return batch
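
For readers who want the collation step in isolation, the sketch below reproduces the idea behind `_convert_outputs_to_ndarray_batch` without any Ray dependency. The helper name `batch_detection_outputs` is hypothetical, and the manual object-array construction stands in for `create_ragged_ndarray`.

import collections
from typing import Dict, List

import numpy as np
import torch


def batch_detection_outputs(
    outputs: List[Dict[str, torch.Tensor]],
) -> Dict[str, np.ndarray]:
    # Group the per-image tensors by key ("boxes", "labels", "scores").
    grouped = collections.defaultdict(list)
    for output in outputs:
        for key, value in output.items():
            grouped[key].append(value.detach().cpu().numpy())

    batch = {}
    for key, values in grouped.items():
        # Build a 1-D object array by hand so NumPy never tries to stack
        # per-image arrays with different detection counts into a single
        # rectangular array.
        column = np.empty(len(values), dtype=object)
        for i, value in enumerate(values):
            column[i] = value
        batch[key] = column
    return batch
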
24 changes: 15 additions & 9 deletions python/ray/train/torch/torch_predictor.py
@@ -4,12 +4,12 @@
import numpy as np
import torch

from ray.util import log_once
from ray.train.predictor import DataBatchType
from ray.air.checkpoint import Checkpoint
from ray.air._internal.torch_utils import convert_ndarray_batch_to_torch_tensor_batch
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.air.checkpoint import Checkpoint
from ray.train._internal.dl_predictor import DLPredictor
from ray.train.predictor import DataBatchType
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.util import log_once
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
@@ -38,12 +38,16 @@ def __init__(
):
self.model = model
self.model.eval()

# TODO (jiaodong): #26249 Use multiple GPU devices with sharded input
self.use_gpu = use_gpu

if use_gpu:
# Ensure input tensor and model live on GPU for GPU inference
self.model.to(torch.device("cuda"))
# TODO (jiaodong): #26249 Use multiple GPU devices with sharded input
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")

# Ensure input tensor and model live on the same device
self.model.to(self.device)

if (
not use_gpu
@@ -117,6 +121,8 @@ def call_model(

.. testcode::

from ray.train.torch import TorchPredictor

# List outputs are not supported by default TorchPredictor.
# So let's define a custom TorchPredictor and override call_model
class MyModel(torch.nn.Module):
@@ -231,7 +237,7 @@ def _arrays_to_tensors(
return convert_ndarray_batch_to_torch_tensor_batch(
numpy_arrays,
dtypes=dtype,
device="cuda" if self.use_gpu else None,
device=self.device,
)

def _tensor_to_array(self, tensor: torch.Tensor) -> np.ndarray:
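
The torch_predictor.py change above replaces the GPU-only `model.to(torch.device("cuda"))` call with a `self.device` attribute that both the model and the input tensors are moved to. A standalone sketch of that pattern, with a placeholder model and input:

import torch

use_gpu = torch.cuda.is_available()
device = torch.device("cuda") if use_gpu else torch.device("cpu")

# Keep the model and every input tensor on the same device so inference never
# mixes CPU and GPU tensors.
model = torch.nn.Linear(4, 2).to(device)
inputs = torch.zeros((8, 4), device=device)

with torch.no_grad():
    outputs = model(inputs)
print(outputs.shape)  # torch.Size([8, 2])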