[AIR] Add TorchVisionPreprocessor (ray-project#30578)
Co-authored-by: Clark Zinzow <clarkzinzow@gmail.com>
Closes ray-project#30403

Signed-off-by: tmynn <hovhannes.tamoyan@gmail.com>
bveeramani authored and tamohannes committed Jan 25, 2023
1 parent 887eb88 commit c488327
Showing 9 changed files with 285 additions and 88 deletions.
33 changes: 10 additions & 23 deletions doc/source/ray-air/examples/torch_image_batch_pretrained.py
@@ -1,31 +1,10 @@
from typing import Dict

import numpy as np

import torch
from torchvision import transforms
from torchvision.models import resnet18

import ray
from ray.train.torch import TorchCheckpoint, TorchPredictor
from ray.train.batch_predictor import BatchPredictor
from ray.data.preprocessors import BatchMapper


def preprocess(image_batch: Dict[str, np.ndarray]) -> np.ndarray:
"""
User PyTorch code to transform user image with outer dimension of batch size.
"""
preprocess = transforms.Compose(
[
# Torchvision's ToTensor does not accept outer batch dimension
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
# Outer dimension is batch size such as (10, 256, 256, 3) -> (10, 3, 256, 256)
transposed_torch_tensor = torch.Tensor(image_batch["image"].transpose(0, 3, 1, 2))
return preprocess(transposed_torch_tensor).numpy()
from ray.data.preprocessors import TorchVisionPreprocessor


data_url = "s3://anonymous@air-example-data-2/1G-image-data-synthetic-raw"
@@ -34,7 +13,15 @@ def preprocess(image_batch: Dict[str, np.ndarray]) -> np.ndarray:

model = resnet18(pretrained=True)

preprocessor = BatchMapper(preprocess, batch_format="numpy")
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
preprocessor = TorchVisionPreprocessor(columns=["image"], transform=transform)

ckpt = TorchCheckpoint.from_model(model=model, preprocessor=preprocessor)

predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor)
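For readers porting similar pipelines, here is a minimal before/after sketch of the change above (a sketch, not part of the diff; it assumes a dataset whose "image" column holds channels-last ndarrays):

import numpy as np
import torch
from torchvision import transforms

from ray.data.preprocessors import BatchMapper, TorchVisionPreprocessor

transform = transforms.Compose(
    [
        transforms.ToTensor(),  # (H, W, C) ndarray -> (C, H, W) float tensor
        transforms.CenterCrop(224),
    ]
)

# After: TorchVisionPreprocessor applies `transform` image by image, so the
# manual transpose and batch plumbing disappear.
preprocessor = TorchVisionPreprocessor(columns=["image"], transform=transform)

# Before: BatchMapper needed hand-written (B, H, W, C) -> (B, C, H, W) handling.
def preprocess(batch: dict) -> dict:
    tensor = torch.as_tensor(batch["image"].transpose(0, 3, 1, 2))
    batch["image"] = transforms.CenterCrop(224)(tensor).numpy()
    return batch

old_preprocessor = BatchMapper(preprocess, batch_format="numpy")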
38 changes: 30 additions & 8 deletions doc/source/ray-air/examples/torch_image_example.ipynb
@@ -97,12 +97,8 @@
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"\n",
"transform = transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
")\n",
"\n",
"train_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=True, transform=transform)\n",
"test_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=False, transform=transform)\n",
"train_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=True)\n",
"test_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=False)\n",
"\n",
"train_dataset: ray.data.Dataset = ray.data.from_torch(train_dataset)\n",
"test_dataset: ray.data.Dataset = ray.data.from_torch(test_dataset)"
@@ -189,11 +185,12 @@
"source": [
"from typing import Dict, Tuple\n",
"import numpy as np\n",
"from PIL.Image import Image\n",
"import torch\n",
"\n",
"\n",
"def convert_batch_to_numpy(batch: Tuple[torch.Tensor, int]) -> Dict[str, np.ndarray]:\n",
" images = np.array([image.numpy() for image, _ in batch])\n",
"def convert_batch_to_numpy(batch: Tuple[Image, int]) -> Dict[str, np.ndarray]:\n",
" images = np.stack([np.array(image) for image, _ in batch])\n",
" labels = np.array([label for _, label in batch])\n",
" return {\"image\": images, \"label\": labels}\n",
"\n",
@@ -332,6 +329,30 @@
" session.report(metrics, checkpoint=checkpoint)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "76f83b27",
"metadata": {},
"source": [
"To improve our model's accuracy, we'll also define a `Preprocessor` to normalize the images."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f25ced31",
"metadata": {},
"outputs": [],
"source": [
"from ray.data.preprocessors import TorchVisionPreprocessor\n",
"\n",
"transform = transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
")\n",
"preprocessor = TorchVisionPreprocessor(columns=[\"image\"], transform=transform)"
]
},
{
"cell_type": "markdown",
"id": "58100f87",
@@ -488,6 +509,7 @@
" train_loop_config={\"batch_size\": 2},\n",
" datasets={\"train\": train_dataset},\n",
" scaling_config=ScalingConfig(num_workers=2),\n",
" preprocessor=preprocessor\n",
")\n",
"result = trainer.fit()\n",
"latest_checkpoint = result.checkpoint"
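Passing preprocessor= to TorchTrainer in the cell above replaces the manual per-batch normalization. Roughly, a sketch of the equivalent manual step (TorchVisionPreprocessor is not fittable, so transform alone is enough; the exact trainer internals may differ):

# Roughly equivalent manual step; AIR also carries `preprocessor` in the
# resulting checkpoint so it is re-applied at batch-prediction time.
transformed_train = preprocessor.transform(train_dataset)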
24 changes: 9 additions & 15 deletions doc/source/ray-air/examples/torch_incremental_learning.ipynb
@@ -515,20 +515,14 @@
"import torch\n",
"from torchvision import transforms\n",
"\n",
"from ray.data.preprocessors import BatchMapper\n",
"from ray.data.preprocessors import TorchVisionPreprocessor\n",
"\n",
"def preprocess_images(image_batch_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:\n",
" \"\"\"Preprocess images by scaling each channel in the image.\"\"\"\n",
" torchvision_transforms = transforms.Compose(\n",
" [transforms.Normalize((0.1307,), (0.3081,))]\n",
" )\n",
" # Outer dimension is batch size such as (4096, 28, 28)\n",
" image_batch_dict[\"image\"] = torchvision_transforms(\n",
" torch.Tensor(image_batch_dict[\"image\"])\n",
" ).numpy()\n",
" return image_batch_dict\n",
"\n",
"mnist_normalize_preprocessor = BatchMapper(fn=preprocess_images, batch_format=\"numpy\")"
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
"])\n",
"mnist_normalize_preprocessor = TorchVisionPreprocessor(columns=[\"image\"], transform=transform)"
]
},
{
@@ -1812,7 +1806,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -1826,11 +1820,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.8 (main, Oct 13 2022, 09:48:40) [Clang 14.0.0 (clang-1400.0.29.102)]"
},
"vscode": {
"interpreter": {
"hash": "99d89bfe98f3aa2d7facda0d08d31ff2a0af9559e5330d719288ce64a1966273"
"hash": "c704e19737f24b51bc631dadcac7a7e356bb35d1c5cd7766248d8a6946059909"
}
}
},
6 changes: 6 additions & 0 deletions doc/source/ray-air/package-ref.rst
@@ -83,6 +83,12 @@ K-Bins Discretizers
.. autoclass:: ray.data.preprocessors.UniformKBinsDiscretizer
:show-inheritance:

Image Preprocessors
###################

.. autoclass:: ray.data.preprocessors.TorchVisionPreprocessor
:show-inheritance:

Text Encoders
#############

2 changes: 2 additions & 0 deletions python/ray/data/preprocessors/__init__.py
@@ -18,6 +18,7 @@
)
from ray.data.preprocessors.concatenator import Concatenator
from ray.data.preprocessors.tokenizer import Tokenizer
from ray.data.preprocessors.torch import TorchVisionPreprocessor
from ray.data.preprocessors.transformer import PowerTransformer
from ray.data.preprocessors.vectorizer import CountVectorizer, HashingVectorizer
from ray.data.preprocessors.discretizer import (
@@ -45,6 +46,7 @@
"StandardScaler",
"Concatenator",
"Tokenizer",
"TorchVisionPreprocessor",
"CustomKBinsDiscretizer",
"UniformKBinsDiscretizer",
]
95 changes: 95 additions & 0 deletions python/ray/data/preprocessors/torch.py
@@ -0,0 +1,95 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Union

import numpy as np

from ray.data.preprocessor import Preprocessor
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
import torch


@PublicAPI(stability="alpha")
class TorchVisionPreprocessor(Preprocessor):
"""Apply a `TorchVision transform <https://pytorch.org/vision/stable/transforms.html>`_
to image columns.
Examples:
>>> import ray
>>> dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
>>> dataset # doctest: +ellipsis
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(..., 3), dtype=float)})
:class:`TorchVisionPreprocessor` passes ndarrays to your transform. To convert
ndarrays to Torch tensors, add ``ToTensor`` to your pipeline.
>>> from torchvision import transforms
>>> from ray.data.preprocessors import TorchVisionPreprocessor
>>> transform = transforms.Compose([
... transforms.ToTensor(),
... transforms.Resize((224, 224)),
... ])
>>> preprocessor = TorchVisionPreprocessor(["image"], transform=transform)
>>> preprocessor.transform(dataset) # doctest: +ellipsis
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(3, 224, 224), dtype=float)})
For better performance, set ``batched`` to ``True`` and replace ``ToTensor``
with a batch-supporting ``Lambda``.
>>> import torch
>>> transform = transforms.Compose([
... transforms.Lambda(
... lambda batch: torch.as_tensor(batch).permute(0, 3, 1, 2)
... ),
... transforms.Resize((224, 224))
... ])
>>> preprocessor = TorchVisionPreprocessor(
... ["image"], transform=transform, batched=True
... )
>>> preprocessor.transform(dataset) # doctest: +ellipsis
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(3, 224, 224), dtype=float)})
Args:
columns: The columns to apply the TorchVision transform to.
transform: The TorchVision transform you want to apply. This transform should
accept an ``np.ndarray`` as input and return a ``torch.Tensor`` as output.
batched: If ``True``, apply ``transform`` to batches of shape
:math:`(B, H, W, C)`. Otherwise, apply ``transform`` to individual images.
""" # noqa: E501

_is_fittable = False

def __init__(
self,
columns: List[str],
transform: Callable[["np.ndarray"], "torch.Tensor"],
batched: bool = False,
):
self._columns = columns
self._fn = transform
self._batched = batched

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(columns={self._columns}, "
f"transform={self._fn!r})"
)

def _transform_numpy(
self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
def transform(batch: np.ndarray) -> np.ndarray:
if self._batched:
return self._fn(batch).numpy()
return np.array([self._fn(array).numpy() for array in batch])

if isinstance(np_data, dict):
outputs = {}
for column, batch in np_data.items():
if column in self._columns:
outputs[column] = transform(batch)
else:
outputs[column] = batch
else:
outputs = transform(np_data)

return outputs
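
To make the two dispatch paths in _transform_numpy concrete, here is a small self-contained sketch (a toy batch, not part of the diff):

import numpy as np
import torch
from torchvision import transforms

# A toy batch of four 32x32 RGB images in channels-last (B, H, W, C) layout.
batch = np.zeros((4, 32, 32, 3), dtype=np.uint8)

# batched=False: the transform sees one (H, W, C) image at a time.
per_image = transforms.ToTensor()  # -> (C, H, W) float tensor in [0, 1]
out = np.array([per_image(image).numpy() for image in batch])
assert out.shape == (4, 3, 32, 32)

# batched=True: the transform sees the whole (B, H, W, C) batch at once.
batched = transforms.Lambda(lambda b: torch.as_tensor(b).permute(0, 3, 1, 2))
out = batched(batch).numpy()
assert out.shape == (4, 3, 32, 32)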