Add implementation for DALI AA and TA with readme
Adjust some configuration options to accommodate it.
Remove the obsolete pipeline.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
klecki committed Mar 6, 2023
1 parent d5e56ca commit c792869
Showing 5 changed files with 306 additions and 161 deletions.
1 change: 1 addition & 0 deletions docs/examples/use_cases/index.py
@@ -4,6 +4,7 @@
"video_superres/README.rst",
"pytorch/resnet50/pytorch-resnet50.rst",
"pytorch/single_stage_detector/pytorch_ssd.rst",
"pytorch/efficientnet/readme.rst",
"tensorflow/resnet-n/README.rst",
"tensorflow/yolov4/readme.rst",
"tensorflow/efficientdet/README.rst",
84 changes: 84 additions & 0 deletions docs/examples/use_cases/pytorch/efficientnet/image_classification/dali.py
@@ -0,0 +1,84 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvidia.dali import fn
from nvidia.dali import types

from nvidia.dali.pipeline.experimental import pipeline_def

from nvidia.dali.auto_aug import auto_augment, trivial_augment


@pipeline_def(enable_conditionals=True)
def training_pipe(data_dir, interpolation, image_size, automatic_augmentation, dali_device="gpu",
                  rank=0, world_size=1):
    rng = fn.random.coin_flip(probability=0.5)

    jpegs, labels = fn.readers.file(name="Reader", file_root=data_dir, shard_id=rank,
                                    num_shards=world_size, random_shuffle=True,
                                    pad_last_batch=True)

    if dali_device == "gpu":
        decoder_device = "mixed"
        rrc_device = "gpu"
    else:
        decoder_device = "cpu"
        rrc_device = "cpu"

    images = fn.decoders.image(jpegs, device=decoder_device, output_type=types.RGB,
                               device_memory_padding=211025920, host_memory_padding=140544512)

    images = fn.random_resized_crop(images, device=rrc_device, size=[image_size, image_size],
                                    interp_type=interpolation,
                                    random_aspect_ratio=[0.75, 4.0 / 3.0],
                                    random_area=[0.08, 1.0],
                                    num_attempts=100, antialias=False)

    # Make sure that from this point on we are processing on the GPU, regardless of the
    # dali_device parameter.
    images = images.gpu()

    images = fn.flip(images, horizontal=rng)

    # Based on the specification, apply the automatic augmentation policy. Note that from the
    # point of view of the pipeline definition, this `if` statement relies on a static scalar
    # parameter, so it is evaluated exactly once during build - we either include the automatic
    # augmentations in the graph or not.
    if automatic_augmentation == "autoaugment":
        shapes = fn.peek_image_shape(jpegs)
        output = auto_augment.auto_augment_image_net(images, shapes)
    elif automatic_augmentation == "trivialaugment":
        output = trivial_augment.trivial_augment_wide(images)
    else:
        output = images

    output = fn.crop_mirror_normalize(output, dtype=types.FLOAT, output_layout=types.NCHW,
                                      crop=(image_size, image_size),
                                      mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                      std=[0.229 * 255, 0.224 * 255, 0.225 * 255])

    return output, labels


@pipeline_def
def validation_pipe(data_dir, interpolation, image_size, image_crop, rank=0, world_size=1):
    jpegs, label = fn.readers.file(file_root=data_dir, shard_id=rank, num_shards=world_size,
                                   random_shuffle=False, pad_last_batch=True)

    images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)

    images = fn.resize(images, resize_shorter=image_size, interp_type=interpolation,
                       antialias=False)

    output = fn.crop_mirror_normalize(images, dtype=types.FLOAT, output_layout=types.NCHW,
                                      crop=(image_crop, image_crop),
                                      mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                      std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    return output, label
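
For orientation, a minimal sketch of how the new training_pipe can be built and run on its own (this snippet is not part of the commit; the data path, batch size, and device values are illustrative):

from nvidia.dali import types
from image_classification.dali import training_pipe

# Illustrative smoke test: any ImageNet-style directory tree works as data_dir.
pipe = training_pipe(data_dir="/data/imagenet/train", interpolation=types.INTERP_LINEAR,
                     image_size=224, automatic_augmentation="autoaugment", dali_device="gpu",
                     batch_size=64, num_threads=4, device_id=0, seed=12)
pipe.build()
images, labels = pipe.run()  # one batch: float32 NCHW images on the GPU, plus labels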
docs/examples/use_cases/pytorch/efficientnet/image_classification/dataloaders.py
@@ -30,8 +30,6 @@
import os
import torch
import numpy as np
-import torchvision.datasets as datasets
-import torchvision.transforms as transforms
from PIL import Image
from functools import partial

@@ -44,13 +42,17 @@
    import nvidia.dali.ops as ops
    import nvidia.dali.types as types

-   DATA_BACKEND_CHOICES.append("dali-gpu")
-   DATA_BACKEND_CHOICES.append("dali-cpu")
-except ImportError:
+   from image_classification.dali import training_pipe, validation_pipe
+
+   DATA_BACKEND_CHOICES.append("dali")
+except ImportError as e:
    print(
        "Please install DALI from https://www.github.com/NVIDIA/DALI to run this example."
    )

+# TODO(klecki): Move it back again
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms


def load_jpeg_from_file(path, cuda=True):
    img_transforms = transforms.Compose(
@@ -75,135 +77,6 @@ def load_jpeg_from_file(path, cuda=True):

    return input


-class HybridTrainPipe(Pipeline):
-    def __init__(
-        self,
-        batch_size,
-        num_threads,
-        device_id,
-        data_dir,
-        interpolation,
-        crop,
-        dali_cpu=False,
-    ):
-        super(HybridTrainPipe, self).__init__(
-            batch_size, num_threads, device_id, seed=12 + device_id
-        )
-        interpolation = {
-            "bicubic": types.INTERP_CUBIC,
-            "bilinear": types.INTERP_LINEAR,
-            "triangular": types.INTERP_TRIANGULAR,
-        }[interpolation]
-        if torch.distributed.is_initialized():
-            rank = torch.distributed.get_rank()
-            world_size = torch.distributed.get_world_size()
-        else:
-            rank = 0
-            world_size = 1
-
-        self.input = ops.FileReader(
-            file_root=data_dir,
-            shard_id=rank,
-            num_shards=world_size,
-            random_shuffle=True,
-            pad_last_batch=True,
-        )
-
-        if dali_cpu:
-            dali_device = "cpu"
-            self.decode = ops.ImageDecoder(device=dali_device, output_type=types.RGB)
-        else:
-            dali_device = "gpu"
-            # This padding sets the size of the internal nvJPEG buffers to be able to
-            # handle all images from full-sized ImageNet without additional reallocations
-            self.decode = ops.ImageDecoder(
-                device="mixed",
-                output_type=types.RGB,
-                device_memory_padding=211025920,
-                host_memory_padding=140544512,
-            )
-
-        self.res = ops.RandomResizedCrop(
-            device=dali_device,
-            size=[crop, crop],
-            interp_type=interpolation,
-            random_aspect_ratio=[0.75, 4.0 / 3.0],
-            random_area=[0.08, 1.0],
-            num_attempts=100,
-            antialias=False,
-        )
-
-        self.cmnp = ops.CropMirrorNormalize(
-            device="gpu",
-            dtype=types.FLOAT,
-            output_layout=types.NCHW,
-            crop=(crop, crop),
-            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
-        )
-        self.coin = ops.CoinFlip(probability=0.5)
-
-    def define_graph(self):
-        rng = self.coin()
-        self.jpegs, self.labels = self.input(name="Reader")
-        images = self.decode(self.jpegs)
-        images = self.res(images)
-        output = self.cmnp(images.gpu(), mirror=rng)
-        return [output, self.labels]
-
-
-class HybridValPipe(Pipeline):
-    def __init__(
-        self, batch_size, num_threads, device_id, data_dir, interpolation, crop, size
-    ):
-        super(HybridValPipe, self).__init__(
-            batch_size, num_threads, device_id, seed=12 + device_id
-        )
-        interpolation = {
-            "bicubic": types.INTERP_CUBIC,
-            "bilinear": types.INTERP_LINEAR,
-            "triangular": types.INTERP_TRIANGULAR,
-        }[interpolation]
-        if torch.distributed.is_initialized():
-            rank = torch.distributed.get_rank()
-            world_size = torch.distributed.get_world_size()
-        else:
-            rank = 0
-            world_size = 1
-
-        self.input = ops.FileReader(
-            file_root=data_dir,
-            shard_id=rank,
-            num_shards=world_size,
-            random_shuffle=False,
-            pad_last_batch=True,
-        )
-
-        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
-        self.res = ops.Resize(
-            device="gpu",
-            resize_shorter=size,
-            interp_type=interpolation,
-            antialias=False,
-        )
-        self.cmnp = ops.CropMirrorNormalize(
-            device="gpu",
-            dtype=types.FLOAT,
-            output_layout=types.NCHW,
-            crop=(crop, crop),
-            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
-            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
-        )
-
-    def define_graph(self):
-        self.jpegs, self.labels = self.input(name="Reader")
-        images = self.decode(self.jpegs)
-        images = self.res(images)
-        output = self.cmnp(images)
-        return [output, self.labels]
-
-
class DALIWrapper(object):
    def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format):
        for data in dalipipeline:
@@ -226,15 +99,15 @@ def __iter__(self):
        )


-def get_dali_train_loader(dali_cpu=False):
+def get_dali_train_loader(dali_device="gpu"):
    def gdtl(
        data_path,
        image_size,
        batch_size,
        num_classes,
        one_hot,
        interpolation="bilinear",
-       augmentation=None,
+       augmentation="disabled",
        start_epoch=0,
        workers=5,
        _worker_init_fn=None,
@@ -248,21 +121,24 @@ def gdtl(
            rank = 0
            world_size = 1

+       interpolation = {
+           "bicubic": types.INTERP_CUBIC,
+           "bilinear": types.INTERP_LINEAR,
+           "triangular": types.INTERP_TRIANGULAR,
+       }[interpolation]
+
        traindir = os.path.join(data_path, "train")
-       if augmentation is not None:
-           raise NotImplementedError(
-               f"Augmentation {augmentation} for dali loader is not supported"
-           )

-       pipe = HybridTrainPipe(
-           batch_size=batch_size,
-           num_threads=workers,
-           device_id=rank % torch.cuda.device_count(),
-           data_dir=traindir,
-           interpolation=interpolation,
-           crop=image_size,
-           dali_cpu=dali_cpu,
-       )
+       pipeline_kwargs = {
+           "batch_size": batch_size,
+           "num_threads": workers,
+           "device_id": rank % torch.cuda.device_count(),
+           "seed": 12 + rank % torch.cuda.device_count(),
+       }
+
+       pipe = training_pipe(data_dir=traindir, interpolation=interpolation,
+                            image_size=image_size, automatic_augmentation=augmentation,
+                            dali_device=dali_device, rank=rank, world_size=world_size,
+                            **pipeline_kwargs)

        pipe.build()
        train_loader = DALIClassificationIterator(
@@ -298,17 +174,24 @@ def gdvl(
            rank = 0
            world_size = 1

+       interpolation = {
+           "bicubic": types.INTERP_CUBIC,
+           "bilinear": types.INTERP_LINEAR,
+           "triangular": types.INTERP_TRIANGULAR,
+       }[interpolation]
+
        valdir = os.path.join(data_path, "val")

-       pipe = HybridValPipe(
-           batch_size=batch_size,
-           num_threads=workers,
-           device_id=rank % torch.cuda.device_count(),
-           data_dir=valdir,
-           interpolation=interpolation,
-           crop=image_size,
-           size=image_size + crop_padding,
-       )
+       pipeline_kwargs = {
+           "batch_size": batch_size,
+           "num_threads": workers,
+           "device_id": rank % torch.cuda.device_count(),
+           "seed": 12 + rank % torch.cuda.device_count(),
+       }
+
+       pipe = validation_pipe(data_dir=valdir, interpolation=interpolation,
+                              image_size=image_size + crop_padding, image_crop=image_size,
+                              rank=rank, world_size=world_size, **pipeline_kwargs)

        pipe.build()
        val_loader = DALIClassificationIterator(
@@ -430,8 +313,13 @@ def get_pytorch_train_loader(
        transforms.RandomResizedCrop(image_size, interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
    ]
-   if augmentation == "autoaugment":
+   if augmentation == "disabled":
+       pass
+   elif augmentation == "autoaugment":
        transforms_list.append(AutoaugmentImageNetPolicy())
+   else:
+       raise NotImplementedError(f"Automatic augmentation: '{augmentation}' is not supported"
+                                 " for PyTorch data loader.")
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(transforms_list))

    if torch.distributed.is_initialized():
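For reference, a sketch of how the refactored loader factory is consumed. It assumes the training script keeps the existing contract in which the factory's inner function returns a loader plus a step count; all argument values are illustrative:

get_train_loader = get_dali_train_loader(dali_device="gpu")
train_loader, steps_per_epoch = get_train_loader(
    "/data/imagenet", image_size=224, batch_size=128, num_classes=1000,
    one_hot=False, interpolation="bilinear", augmentation="autoaugment", workers=5)

for input, target in train_loader:
    break  # input and target are CUDA tensors prepared by DALIWrapper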
2 changes: 1 addition & 1 deletion docs/examples/use_cases/pytorch/efficientnet/main.py
@@ -309,7 +309,7 @@ def add_parser_arguments(parser, skip_arch=False):
    parser.add_argument(
        "--memory-format",
        type=str,
-       default="nchw",
+       default="nhwc",
        choices=["nchw", "nhwc"],
        help="memory layout, nchw or nhwc",
    )
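
The new nhwc default corresponds to PyTorch's channels-last memory format. Below is a minimal illustration of the conversion this flag implies elsewhere in the script (assumed usage, not code from this diff):

import torch

# "nhwc" maps to torch.channels_last, "nchw" to torch.contiguous_format.
model = torch.nn.Conv2d(3, 64, kernel_size=3).cuda().to(memory_format=torch.channels_last)
x = torch.randn(8, 3, 224, 224, device="cuda").to(memory_format=torch.channels_last)
y = model(x)  # lets cuDNN pick NHWC kernels, which run faster on Tensor Cores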
docs/examples/use_cases/pytorch/efficientnet/readme.rst (diff not loaded)
