-
Notifications
You must be signed in to change notification settings - Fork 157
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add benchmark from torchvision training references (#714)
Summary: Towards #416 This is a modified and simplified version of the torchvision classification training reference that provides: - Distributed Learning (DDP) vs 1-GPU training - Datapipes (with DataLoader or torchdata.dataloader2) vs Iterable datasets (non-DP) vs MapStyle Datasets - Full training procedure or Data-loading only (with or without transforms) or Model training only (generating fake datasets) - Timing of data-loading vs model training - any classification model from torchvision I removed a lot of non-essential features from the original reference, but I can simplify further. Typically I would expect the `MetricLogger` to disappear, or be trimmed down to its most essential bits. Pull Request resolved: #714 Reviewed By: NivekT Differential Revision: D38569273 Pulled By: NicolasHug fbshipit-source-id: 1bc4442ab826256123f8360c14dc8b3eccd73256
- Loading branch information
1 parent
5dade9a
commit 3a348a7
Showing
4 changed files
with
830 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import itertools | ||
import os | ||
import random | ||
from functools import partial | ||
from pathlib import Path | ||
|
||
import torch | ||
import torch.distributed as dist | ||
import torchvision | ||
from PIL import Image | ||
from torchdata.datapipes.iter import FileLister, IterDataPipe | ||
|
||
|
||
# TODO: maybe infinite buffer can / is already natively supported by torchdata?
# Effectively-unbounded buffer size passed to shuffle() so the whole pipe fits in the buffer.
INFINITE_BUFFER_SIZE = 1_000_000_000

# Canonical ImageNet-1k split sizes (number of samples), used by _LenSetter.
IMAGENET_TRAIN_LEN = 1_281_167
IMAGENET_TEST_LEN = 50_000
|
||
|
||
class _LenSetter(IterDataPipe):
    # TODO: Ideally, we wouldn't need this extra class
    """Wrap a datapipe and expose a hard-coded ImageNet split length via ``len()``.

    The split is inferred from ``root``: a path containing "train" maps to the
    ImageNet train length, one containing "val" to the validation length.
    """

    def __init__(self, dp, root):
        self.dp = dp

        if "train" in str(root):
            self.size = IMAGENET_TRAIN_LEN
        elif "val" in str(root):
            self.size = IMAGENET_TEST_LEN
        else:
            # Fix: the original raised ValueError("oops?"), which gave the user
            # no hint about what was wrong or how to fix it.
            raise ValueError(f"Invalid root {root!r}: expected a path containing 'train' or 'val'")

    def __iter__(self):
        yield from self.dp

    def __len__(self):
        # TODO The // world_size part shouldn't be needed. See https://github.com/pytorch/data/issues/533
        # Robustness fix: dist.get_world_size() raises when no process group is
        # initialized, so fall back to 1 for single-process (non-DDP) runs.
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
        else:
            world_size = 1
        return self.size // world_size
|
||
|
||
def _decode(path, root, category_to_int):
    """Load one image sample from disk.

    Args:
        path: path to an image file located under ``root``.
        root: dataset root directory; the first path component of ``path``
            relative to ``root`` is the category (class) name.
        category_to_int: dict mapping category name to integer label
            (built by ``make_dp``).

    Returns:
        Tuple of (PIL.Image in RGB mode, int label).
    """
    category = Path(path).relative_to(root).parts[0]

    image = Image.open(path).convert("RGB")
    # Bug fix: category_to_int is a dict (see make_dp), so it must be indexed,
    # not called -- `category_to_int(category)` raised
    # TypeError: 'dict' object is not callable at iteration time.
    label = category_to_int[category]

    return image, label
|
||
|
||
def _apply_tranforms(img_and_label, transforms): | ||
img, label = img_and_label | ||
return transforms(img), label | ||
|
||
|
||
def make_dp(root, transforms):
    """Build the ImageNet datapipe rooted at ``root``.

    Recursively lists ``*.JPEG`` files, shuffles (actual shuffling is toggled
    later by the DataLoader), shards across workers, decodes each image and
    applies ``transforms``, then attaches the split length via ``_LenSetter``.
    """
    resolved_root = Path(root).expanduser().resolve()
    categories = sorted(entry.name for entry in os.scandir(resolved_root) if entry.is_dir())
    # Class names -> contiguous integer labels, in sorted order.
    category_to_int = {name: idx for idx, name in enumerate(categories)}

    pipe = FileLister(str(resolved_root), recursive=True, masks=["*.JPEG"])

    pipe = pipe.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
    pipe = pipe.map(partial(_decode, root=resolved_root, category_to_int=category_to_int))
    pipe = pipe.map(partial(_apply_tranforms, transforms=transforms))

    return _LenSetter(pipe, root=resolved_root)
|
||
|
||
class PreLoadedMapStyle:
    """Map-style dataset whose samples are all pre-loaded and pre-transformed in __init__.

    Because the data work happens up front, the DataLoader should be extremely
    fast -- this only measures how quickly a model could train if there were
    no data bottleneck at all.
    """

    def __init__(self, dir, transform, buffer_size=100):
        source = torchvision.datasets.ImageFolder(dir, transform=transform)
        self.size = len(source)
        # Keep only `buffer_size` randomly-picked samples; __getitem__ cycles over them.
        self.samples = [source[torch.randint(0, len(source), size=(1,)).item()] for _ in range(buffer_size)]

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.samples[idx % len(self.samples)]
|
||
|
||
class _PreLoadedDP(IterDataPipe):
    """DataPipe variant of PreLoadedMapStyle: every sample is pre-loaded in __init__."""

    def __init__(self, root, transforms, buffer_size=100):
        source = torchvision.datasets.ImageFolder(root, transform=transforms)
        self.size = len(source)
        self.samples = [source[torch.randint(0, len(source), size=(1,)).item()] for _ in range(buffer_size)]
        # NOTE: the RNG might differ across DDP workers, so each worker ends up
        # with different samples. Accuracy isn't measured here, so that's fine.

    def __iter__(self):
        for index in range(self.size):
            yield self.samples[index % len(self.samples)]
|
||
|
||
def make_pre_loaded_dp(root, transforms):
    """Build a datapipe over pre-loaded, pre-transformed samples (see _PreLoadedDP)."""
    pipe = _PreLoadedDP(root=root, transforms=transforms)
    pipe = pipe.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
    return _LenSetter(pipe, root=root)
|
||
|
||
class MapStyleToIterable(torch.utils.data.IterableDataset):
    """Convert a map-style dataset into an iterable one, sharding samples
    across both DDP processes and DataLoader workers.

    Not sure this kind of Iterable dataset is actually useful to benchmark. It
    was necessary when benchmarking async-io stuff, but not anymore.
    If anything, it shows how tricky Iterable datasets are to implement.
    """

    def __init__(self, dataset, shuffle):
        self.dataset = dataset
        self.shuffle = shuffle

        self.size = len(self.dataset)
        self.seed = 0  # has to be hard-coded for all DDP workers to have the same shuffling

    @staticmethod
    def _ddp_world_size():
        # Robustness fix: dist.get_world_size() raises when no process group is
        # initialized; fall back to 1 so the dataset also works without DDP.
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size()
        return 1

    @staticmethod
    def _ddp_rank():
        # See _ddp_world_size: rank 0 when not running under DDP.
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    def __len__(self):
        # Each DDP process only sees its own shard of the data.
        return self.size // self._ddp_world_size()

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # Robustness fix: get_worker_info() returns None when iterating in the
        # main process (DataLoader num_workers=0, or no DataLoader at all);
        # the original crashed with AttributeError there.
        if worker_info is None:
            num_dl_workers, dl_worker_id = 1, 0
        else:
            num_dl_workers = worker_info.num_workers
            dl_worker_id = worker_info.id

        num_ddp_workers = self._ddp_world_size()
        ddp_worker_id = self._ddp_rank()

        # Flatten the (DDP process, DataLoader worker) grid into one worker id.
        num_total_workers = num_ddp_workers * num_dl_workers
        current_worker_id = ddp_worker_id + (num_ddp_workers * dl_worker_id)

        indices = range(self.size)
        if self.shuffle:
            rng = random.Random(self.seed)
            indices = rng.sample(indices, k=self.size)
        # Strided slice: each worker gets a disjoint subset of the indices.
        indices = itertools.islice(indices, current_worker_id, None, num_total_workers)

        yield from (self.dataset[i] for i in indices)
|
||
|
||
# TODO: maybe only generate these when --no-transforms is passed? | ||
_RANDOM_IMAGE_TENSORS = [torch.randn(3, 224, 224) for _ in range(300)] | ||
|
||
|
||
def no_transforms(_): | ||
# see --no-transforms doc | ||
return random.choice(_RANDOM_IMAGE_TENSORS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import torch | ||
from torchvision.transforms import transforms | ||
|
||
|
||
class ClassificationPresetTrain:
    """Training preset: random resized crop, optional horizontal flip,
    tensor conversion, and normalization."""

    def __init__(
        self,
        *,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        hflip_prob=0.5,
    ):
        steps = [transforms.RandomResizedCrop(crop_size)]
        if hflip_prob > 0:
            steps.append(transforms.RandomHorizontalFlip(hflip_prob))

        steps += [
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std),
        ]

        self.transforms = transforms.Compose(steps)

    def __call__(self, img):
        return self.transforms(img)
|
||
|
||
class ClassificationPresetEval:
    """Evaluation preset: resize, center crop, tensor conversion, and normalization."""

    def __init__(
        self,
        *,
        crop_size,
        resize_size=256,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ):
        steps = [
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=mean, std=std),
        ]
        self.transforms = transforms.Compose(steps)

    def __call__(self, img):
        return self.transforms(img)
Oops, something went wrong.