[AIR] Add TorchDetectionPredictor
#32199
Merged
Changes from all commits (25 commits)
3203f0c  Initial commit (bveeramani)
28fa37c  Address review comments (bveeramani)
2f5fd21  Update utils.py (bveeramani)
679f3b1  Merge remote-tracking branch 'upstream/master' into promote-create-ra… (bveeramani)
e3a7630  Merge remote-tracking branch 'upstream/master' into promote-create-ra… (bveeramani)
10d78cf  Address review comments (bveeramani)
0d331c5  Update tensor_extension.py (bveeramani)
328b3f8  Update tensor_extension.py (bveeramani)
694764e  Merge remote-tracking branch 'upstream/master' into promote-create-ra… (bveeramani)
9855e94  Initial commit (bveeramani)
bbbda1f  Merge branch 'dtype-optional' into detection-predictor (bveeramani)
22b3aa9  Add `TorchDetectionPredictor` (bveeramani)
ab67b37  Fix test (bveeramani)
750cc0b  Address review comments (bveeramani)
edf14b8  Merge branch 'dtype-optional' into detection-predictor (bveeramani)
43693e5  Merge branch 'promote-create-ragged' into detection-predictor (bveeramani)
dd3a9eb  Address review comments (bveeramani)
6ecf830  Update torch_detection_predictor.py (bveeramani)
6374151  Fix Bazel (bveeramani)
c85f5d1  Merge remote-tracking branch 'upstream/master' into detection-predictor (bveeramani)
2f41527  Merge remote-tracking branch 'upstream/master' into detection-predictor (bveeramani)
bdf28f9  Update BUILD (bveeramani)
163d2af  Update torch_predictor.py (bveeramani)
a55bc0a  Merge remote-tracking branch 'upstream/master' into detection-predictor (bveeramani)
cea988f  Update torch_predictor.py (bveeramani)
python/ray/train/tests/test_torch_detection_predictor.py (118 additions, 0 deletions)
@@ -0,0 +1,118 @@
import numpy as np
import pytest
from torchvision import models

import ray
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
from ray.train.batch_predictor import BatchPredictor
from ray.train.torch import TorchCheckpoint, TorchDetectionPredictor


@pytest.fixture(name="predictor")
def predictor_fixture():
    model = models.detection.maskrcnn_resnet50_fpn()
    yield TorchDetectionPredictor(model=model)


@pytest.mark.parametrize(
    "data",
    [
        np.zeros((1, 3, 32, 32), dtype=np.float32),
        {"image": np.zeros((1, 3, 32, 32), dtype=np.float32)},
        create_ragged_ndarray(
            [
                np.zeros((3, 32, 32), dtype=np.float32),
                np.zeros((3, 64, 64), dtype=np.float32),
            ]
        ),
    ],
)
def test_predict(predictor, data):
    predictions = predictor.predict(data)

    assert all(len(value) == len(data) for value in predictions.values())
    # Boxes should have shape `(# detections, 4)`.
    assert all(boxes.ndim == 2 for boxes in predictions["pred_boxes"])
    assert all(boxes.shape[-1] == 4 for boxes in predictions["pred_boxes"])
    # Labels should have shape `(# detections,)`.
    assert all(labels.ndim == 1 for labels in predictions["pred_labels"])
    # Scores should have shape `(# detections,)`.
    assert all(scores.ndim == 1 for scores in predictions["pred_scores"])


def test_predict_tensor_dataset():
    model = models.detection.maskrcnn_resnet50_fpn()
    checkpoint = TorchCheckpoint.from_model(model)
    predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)
    dataset = ray.data.from_items([np.zeros((3, 32, 32), dtype=np.float32)])

    predictions = predictor.predict(dataset)

    # Boxes should have shape `(# detections, 4)`.
    pred_boxes = [row["pred_boxes"] for row in predictions.take_all()]
    assert all(boxes.ndim == 2 for boxes in pred_boxes)
    assert all(boxes.shape[-1] == 4 for boxes in pred_boxes)
    # Labels should have shape `(# detections,)`.
    pred_labels = [row["pred_labels"] for row in predictions.take_all()]
    assert all(labels.ndim == 1 for labels in pred_labels)
    # Scores should have shape `(# detections,)`.
    pred_scores = [row["pred_scores"] for row in predictions.take_all()]
    assert all(scores.ndim == 1 for scores in pred_scores)


@pytest.mark.parametrize(
    "items",
    [
        [{"image": np.zeros((3, 32, 32), dtype=np.float32)}],
        [
            {"image": np.zeros((3, 32, 32), dtype=np.float32)},
            {"image": np.zeros((3, 64, 64), dtype=np.float32)},
        ],
    ],
)
def test_predict_tabular_dataset(items):
    model = models.detection.maskrcnn_resnet50_fpn()
    checkpoint = TorchCheckpoint.from_model(model)
    predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)
    dataset = ray.data.from_items(items)

    predictions = predictor.predict(dataset)

    assert predictions.count() == len(items)
    # Boxes should have shape `(# detections, 4)`.
    pred_boxes = [row["pred_boxes"] for row in predictions.take_all()]
    assert all(boxes.ndim == 2 for boxes in pred_boxes)
    assert all(boxes.shape[-1] == 4 for boxes in pred_boxes)
    # Labels should have shape `(# detections,)`.
    pred_labels = [row["pred_labels"] for row in predictions.take_all()]
    assert all(labels.ndim == 1 for labels in pred_labels)
    # Scores should have shape `(# detections,)`.
    pred_scores = [row["pred_scores"] for row in predictions.take_all()]
    assert all(scores.ndim == 1 for scores in pred_scores)


def test_multi_column_batch_raises_value_error(predictor):
    data = {
        "image": np.zeros((2, 3, 32, 32), dtype=np.float32),
        "boxes": np.zeros((2, 0, 4), dtype=np.float32),
        "labels": np.zeros((2, 0), dtype=np.int64),
    }
    with pytest.raises(ValueError):
        # `data` should only contain one key. Otherwise, `TorchDetectionPredictor`
        # doesn't know which column contains the input images.
        predictor.predict(data)


def test_invalid_dtype_raises_value_error(predictor):
    data = np.zeros((1, 3, 32, 32), dtype=np.float32)
    with pytest.raises(ValueError):
        # `dtype` should be a single `torch.dtype`.
        predictor.predict(data, dtype=np.float32)


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", "-x", __file__]))
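Aside on the ragged test input above: the sketch below is not part of the PR; it spells out what `create_ragged_ndarray` is assumed to return, namely a single object-dtype ndarray whose elements keep their own shapes, which is what lets one test batch mix 32x32 and 64x64 images.

import numpy as np

from ray.air.util.tensor_extensions.utils import create_ragged_ndarray

# Pack two differently sized images into one array-of-arrays.
images = create_ragged_ndarray(
    [
        np.zeros((3, 32, 32), dtype=np.float32),
        np.zeros((3, 64, 64), dtype=np.float32),
    ]
)
print(images.dtype)      # object (assumed: each element keeps its own shape)
print(len(images))       # 2
print(images[0].shape)   # (3, 32, 32)
print(images[1].shape)   # (3, 64, 64)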
python/ray/train/torch/torch_detection_predictor.py (140 additions, 0 deletions)

@@ -0,0 +1,140 @@
import collections
from typing import Dict, List, Optional, Union

import numpy as np
import torch

from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
from ray.train._internal.dl_predictor import TensorDtype
from ray.train.torch.torch_predictor import TorchPredictor
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class TorchDetectionPredictor(TorchPredictor):
    """A predictor for TorchVision detection models.

    Unlike other Torch models, instance segmentation models return
    `List[Dict[str, Tensor]]`. This predictor extends :class:`TorchPredictor`
    to support the non-standard outputs.

    To learn more about instance segmentation models, read
    `Instance segmentation models <https://pytorch.org/vision/main/auto_examples/plot_visualization_utils.html#instance-seg-output>`_.

    Example:

        .. testcode::

            import numpy as np
            from torchvision import models

            from ray.train.torch import TorchDetectionPredictor

            model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True)

            predictor = TorchDetectionPredictor(model=model)
            predictions = predictor.predict(np.zeros((4, 3, 32, 32), dtype=np.float32))

            print(predictions.keys())

        .. testoutput::

            dict_keys(['pred_boxes', 'pred_labels', 'pred_scores'])

        .. testcode::

            import numpy as np
            from torchvision import models

            import ray
            from ray.train.batch_predictor import BatchPredictor
            from ray.train.torch import TorchCheckpoint, TorchDetectionPredictor

            dataset = ray.data.from_items([{"image": np.zeros((3, 32, 32), dtype=np.float32)}])
            model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True)
            checkpoint = TorchCheckpoint.from_model(model)

            predictor = BatchPredictor.from_checkpoint(checkpoint, TorchDetectionPredictor)
            predictions = predictor.predict(dataset, feature_columns=["image"])

            print(predictions.take(1))

        .. testoutput::

            [{'pred_boxes': array([], shape=(0, 4), dtype=float32), 'pred_labels': array([], dtype=int64), 'pred_scores': array([], dtype=float32)}]
    """  # noqa: E501

    def _predict_numpy(
        self,
        data: Union[np.ndarray, Dict[str, np.ndarray]],
        dtype: Optional[Union[TensorDtype, Dict[str, TensorDtype]]],
    ) -> Dict[str, np.ndarray]:
        if isinstance(data, dict) and len(data) != 1:
            raise ValueError(
                f"""Expected input to contain one key, but got {len(data)} instead.

                If you're using `BatchPredictor`, pass a one-element list to
                `feature_columns`.

                ---
                predictor = BatchPredictor(checkpoint, TorchDetectionPredictor)
                predictor.predict(dataset, feature_columns=["image"])
                ---
                """
            )

        if dtype is not None and not isinstance(dtype, torch.dtype):
            raise ValueError(
                "Expected `dtype` to be a `torch.dtype`, but got a "
                f"{type(dtype).__name__} instead."
            )

        if isinstance(data, dict):
            images = next(iter(data.values()))
        else:
            images = data

        inputs = [
            torch.as_tensor(image, dtype=dtype).to(self.device) for image in images
        ]
        outputs = self.call_model(inputs)
        outputs = _convert_outputs_to_ndarray_batch(outputs)
        outputs = {"pred_" + key: value for key, value in outputs.items()}

        return outputs


def _convert_outputs_to_ndarray_batch(
    outputs: List[Dict[str, torch.Tensor]],
) -> Dict[str, np.ndarray]:
    """Batch detection model outputs.

    TorchVision detection models return `List[Dict[str, Tensor]]`. Each `Dict`
    contains 'boxes', 'labels', and 'scores'.

    >>> import torch
    >>> from torchvision import models
    >>> model = models.detection.fasterrcnn_resnet50_fpn_v2()
    >>> model.eval()  # doctest: +ELLIPSIS
    FasterRCNN(...)
    >>> outputs = model(torch.zeros((2, 3, 32, 32)))
    >>> len(outputs)
    2
    >>> outputs[0].keys()
    dict_keys(['boxes', 'labels', 'scores'])

    This function batches values and returns a `Dict[str, np.ndarray]`.

    >>> from ray.train.torch.torch_detection_predictor import _convert_outputs_to_ndarray_batch
    >>> batch = _convert_outputs_to_ndarray_batch(outputs)
    >>> batch.keys()
    dict_keys(['boxes', 'labels', 'scores'])
    >>> batch["boxes"].shape
    (2,)
    """  # noqa: E501
    batch = collections.defaultdict(list)
    for output in outputs:
        for key, value in output.items():
            batch[key].append(value.cpu().detach().numpy())
    for key, value in batch.items():
        batch[key] = create_ragged_ndarray(value)
    return batch
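For context, here is a hedged usage sketch (not part of the diff) of how the per-image `pred_*` arrays might be consumed downstream. The 0.5 score threshold and the choice of `maskrcnn_resnet50_fpn` weights are illustrative assumptions, not something the PR prescribes.

import numpy as np
from torchvision import models

from ray.train.torch import TorchDetectionPredictor

model = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
predictor = TorchDetectionPredictor(model=model)

# Each output column holds one ndarray per input image.
predictions = predictor.predict(np.zeros((2, 3, 32, 32), dtype=np.float32))
for boxes, labels, scores in zip(
    predictions["pred_boxes"], predictions["pred_labels"], predictions["pred_scores"]
):
    keep = scores > 0.5  # drop low-confidence detections (threshold is arbitrary)
    print(boxes[keep].shape, labels[keep].shape)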