Skip to content

Commit

Permalink
Add Dali MNIST example (#3721)
Browse files Browse the repository at this point in the history
* add MNIST DALI example, update README.md

* Fix PEP8 warnings

* reformatted using black

* add mnist_dali to test_examples.py

* Add documentation as docstrings

* add nvidia-pyindex and nvidia-dali-cuda100

* replace nvidia-pyindex with --extra-index-url

* mark mnist_dali test as Linux and GPU only

* adjust CUDA docker and examples.txt, fix import error in test_examples.py

* adjust the GPU check

* Exit when DALI is not available

* remove requirements-examples.txt and DALI pip install

* Refactored example, moved to new logging api, added runtime check for test and dali script

* Patch to reflect the mnist example module

* add req.

* Apply suggestions from code review

* Removed requirement as it breaks CPU install, added note in README to install DALI

* add DALI to Drone

* test examples

* Apply suggestions from code review

* imports

* ABC

* cuda

* cuda

* pip DALI

* Move build into init function

Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
  • Loading branch information
5 people authored Nov 6, 2020
1 parent f3dfb98 commit 6e5f232
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .drone.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ steps:
- pip --version
- nvidia-smi
- pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir
# when Image has defined CUDa version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0"
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed
- pip list
- coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --color=yes --durations=25 # --flake8
- python -m pytest benchmarks pl_examples -v --color=yes --maxfail=2 --durations=0 # --flake8
Expand Down
10 changes: 9 additions & 1 deletion pl_examples/basic_examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ python mnist.py
python mnist.py --gpus 2 --distributed_backend 'dp'
```

---
---
#### MNIST with DALI
The MNIST example above using [NVIDIA DALI](https://developer.nvidia.com/DALI).
Requires NVIDIA DALI to be installed based on your CUDA version, see [here](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html).
```bash
python mnist_dali.py
```

---
#### Image classifier
Generic image classifier with an arbitrary backbone (ie: a simple system)
```bash
Expand Down
204 changes: 204 additions & 0 deletions pl_examples/basic_examples/mnist_dali.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from argparse import ArgumentParser
from random import shuffle
from warnings import warn

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import random_split

import pytorch_lightning as pl

try:
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
except Exception:
from tests.base.datasets import MNIST

try:
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
except (ImportError, ModuleNotFoundError):
warn('NVIDIA DALI is not available')
ops, types, Pipeline, DALIClassificationIterator = ..., ..., ABC, ABC


class ExternalMNISTInputIterator(object):
"""
This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches
"""

def __init__(self, mnist_ds, batch_size):
self.batch_size = batch_size
self.mnist_ds = mnist_ds
self.indices = list(range(len(self.mnist_ds)))
shuffle(self.indices)

def __iter__(self):
self.i = 0
self.n = len(self.mnist_ds)
return self

def __next__(self):
batch = []
labels = []
for _ in range(self.batch_size):
index = self.indices[self.i]
img, label = self.mnist_ds[index]
batch.append(img.numpy())
labels.append(np.array([label], dtype=np.uint8))
self.i = (self.i + 1) % self.n
return (batch, labels)


class ExternalSourcePipeline(Pipeline):
"""
This DALI pipeline class just contains the MNIST iterator
"""

def __init__(self, batch_size, eii, num_threads, device_id):
super(ExternalSourcePipeline, self).__init__(batch_size, num_threads, device_id, seed=12)
self.source = ops.ExternalSource(source=eii, num_outputs=2)
self.build()

def define_graph(self):
images, labels = self.source()
return images, labels


class DALIClassificationLoader(DALIClassificationIterator):
"""
This class extends DALI's original DALIClassificationIterator with the __len__() function so that we can call len() on it
"""

def __init__(
self,
pipelines,
size=-1,
reader_name=None,
auto_reset=False,
fill_last_batch=True,
dynamic_shape=False,
last_batch_padded=False,
):
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded)

def __len__(self):
batch_count = self._size // (self._num_gpus * self.batch_size)
last_batch = 1 if self._fill_last_batch else 0
return batch_count + last_batch


class LitClassifier(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()

self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x

def split_batch(self, batch):
return batch[0]["data"], batch[0]["label"].squeeze().long()

def training_step(self, batch, batch_idx):
x, y = self.split_batch(batch)
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss

def validation_step(self, batch, batch_idx):
x, y = self.split_batch(batch)
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('valid_loss', loss)

def test_step(self, batch, batch_idx):
x, y = self.split_batch(batch)
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('test_loss', loss)

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parser


def cli_main():
pl.seed_everything(1234)

# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
args = parser.parse_args()

# ------------
# data
# ------------
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

eii_train = ExternalMNISTInputIterator(mnist_train, args.batch_size)
eii_val = ExternalMNISTInputIterator(mnist_val, args.batch_size)
eii_test = ExternalMNISTInputIterator(mnist_test, args.batch_size)

pipe_train = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_train, num_threads=2, device_id=0)
train_loader = DALIClassificationLoader(pipe_train, size=len(mnist_train), auto_reset=True, fill_last_batch=False)

pipe_val = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_val, num_threads=2, device_id=0)
val_loader = DALIClassificationLoader(pipe_val, size=len(mnist_val), auto_reset=True, fill_last_batch=False)

pipe_test = ExternalSourcePipeline(batch_size=args.batch_size, eii=eii_test, num_threads=2, device_id=0)
test_loader = DALIClassificationLoader(pipe_test, size=len(mnist_test), auto_reset=True, fill_last_batch=False)

# ------------
# model
# ------------
model = LitClassifier(args.hidden_dim, args.learning_rate)

# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)

# ------------
# testing
# ------------
trainer.test(test_dataloaders=test_loader)


if __name__ == "__main__":
cli_main()
41 changes: 33 additions & 8 deletions pl_examples/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import platform
from unittest import mock
import torch

import pytest
import torch

try:
from nvidia.dali import ops, types, pipeline, plugin
except (ImportError, ModuleNotFoundError):
DALI_AVAILABLE = False
else:
DALI_AVAILABLE = True

dp_16_args = """
--max_epochs 1 \
Expand Down Expand Up @@ -28,7 +37,7 @@
--precision 16 \
"""


# TODO
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# @pytest.mark.parametrize('cli_args', [dp_16_args])
# def test_examples_dp_mnist(cli_args):
Expand All @@ -38,15 +47,17 @@
# cli_main()


# TODO
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# @pytest.mark.parametrize('cli_args', [dp_16_args])
# def test_examples_dp_image_classifier(cli_args):
# from pl_examples.basic_examples.image_classifier import cli_main
#
# with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
# cli_main()
#
#


# TODO
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# @pytest.mark.parametrize('cli_args', [dp_16_args])
# def test_examples_dp_autoencoder(cli_args):
Expand All @@ -56,24 +67,27 @@
# cli_main()


# TODO
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# @pytest.mark.parametrize('cli_args', [ddp_args])
# def test_examples_ddp_mnist(cli_args):
# from pl_examples.basic_examples.mnist import cli_main
#
# with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
# cli_main()
#
#


# TODO
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# @pytest.mark.parametrize('cli_args', [ddp_args])
# def test_examples_ddp_image_classifier(cli_args):
# from pl_examples.basic_examples.image_classifier import cli_main
#
# with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
# cli_main()
#
#


# TODO
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# @pytest.mark.parametrize('cli_args', [ddp_args])
# def test_examples_ddp_autoencoder(cli_args):
Expand All @@ -92,3 +106,14 @@ def test_examples_cpu(cli_args):
for cli_cmd in [mnist_cli, ic_cli, ae_cli]:
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
cli_cmd()


@pytest.mark.skipif(not DALI_AVAILABLE, reason="Nvidia DALI required")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(platform.system() != 'Linux', reason='Only applies to Linux platform.')
@pytest.mark.parametrize('cli_args', [cpu_args])
def test_examples_mnist_dali(cli_args):
from pl_examples.basic_examples.mnist_dali import cli_main

with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
cli_main()
2 changes: 1 addition & 1 deletion requirements/examples.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torchvision>=0.4.1,<0.9.0
gym>=0.17.0
gym>=0.17.0

0 comments on commit 6e5f232

Please sign in to comment.